/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use serde_json; use std::marker::PhantomData; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; use bytes::Bytes; use http::{self, Method, StatusCode}; use tokio::net::TcpListener; use url::{Host, Url}; use warp::{self, Buf, Filter, Rejection}; use crate::command::{WebDriverCommand, WebDriverMessage}; use crate::error::{ErrorStatus, WebDriverError, WebDriverResult}; use crate::httpapi::{ standard_routes, Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute, }; use crate::response::{CloseWindowResponse, WebDriverResponse}; use crate::Parameters; // Silence warning about Quit being unused for now. #[allow(dead_code)] enum DispatchMessage { HandleWebDriver( WebDriverMessage, Sender>, ), Quit, } #[derive(Clone, Debug, PartialEq)] pub struct Session { pub id: String, } impl Session { fn new(id: String) -> Session { Session { id } } } pub trait WebDriverHandler: Send { fn handle_command( &mut self, session: &Option, msg: WebDriverMessage, ) -> WebDriverResult; fn delete_session(&mut self, session: &Option); } #[derive(Debug)] struct Dispatcher, U: WebDriverExtensionRoute> { handler: T, session: Option, extension_type: PhantomData, } impl, U: WebDriverExtensionRoute> Dispatcher { fn new(handler: T) -> Dispatcher { Dispatcher { handler, session: None, extension_type: PhantomData, } } fn run(&mut self, msg_chan: &Receiver>) { loop { match msg_chan.recv() { Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => { let resp = match self.check_session(&msg) { Ok(_) => self.handler.handle_command(&self.session, msg), Err(e) => Err(e), }; match resp { Ok(WebDriverResponse::NewSession(ref new_session)) => { self.session = Some(Session::new(new_session.session_id.clone())); } Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => { if handles.is_empty() { debug!("Last window was closed, deleting session"); self.delete_session(); } } Ok(WebDriverResponse::DeleteSession) => self.delete_session(), Err(ref x) if x.delete_session => self.delete_session(), _ => {} } if resp_chan.send(resp).is_err() { error!("Sending response to the main thread failed"); }; } Ok(DispatchMessage::Quit) => break, Err(e) => panic!("Error receiving message in handler: {:?}", e), } } } fn delete_session(&mut self) { debug!("Deleting session"); self.handler.delete_session(&self.session); self.session = None; } fn check_session(&self, msg: &WebDriverMessage) -> WebDriverResult<()> { match msg.session_id { Some(ref msg_session_id) => match self.session { Some(ref existing_session) => { if existing_session.id != *msg_session_id { Err(WebDriverError::new( ErrorStatus::InvalidSessionId, format!("Got unexpected session id {}", msg_session_id), )) } else { Ok(()) } } None => Ok(()), }, None => { match self.session { Some(_) => { match msg.command { WebDriverCommand::Status => Ok(()), WebDriverCommand::NewSession(_) => Err(WebDriverError::new( ErrorStatus::SessionNotCreated, "Session is already started", )), _ => { //This should be impossible error!("Got a message with no session id"); Err(WebDriverError::new( ErrorStatus::UnknownError, "Got a command with no session?!", )) } } } None => match msg.command { WebDriverCommand::NewSession(_) => Ok(()), WebDriverCommand::Status => Ok(()), _ => Err(WebDriverError::new( ErrorStatus::InvalidSessionId, "Tried to run a command before creating a session", )), }, } } } } } pub struct Listener { guard: Option>, pub socket: SocketAddr, } impl Drop for Listener { fn drop(&mut self) { let _ = self.guard.take().map(|j| j.join()); } } pub fn start( address: SocketAddr, handler: T, extension_routes: Vec<(Method, &'static str, U)>, ) -> ::std::io::Result where T: 'static + WebDriverHandler, U: 'static + WebDriverExtensionRoute + Send + Sync, { let listener = StdTcpListener::bind(address)?; let addr = listener.local_addr()?; let (msg_send, msg_recv) = channel(); let builder = thread::Builder::new().name("webdriver server".to_string()); let handle = builder.spawn(move || { let mut rt = tokio::runtime::Builder::new() .basic_scheduler() .enable_io() .build() .unwrap(); let mut listener = rt .handle() .enter(|| TcpListener::from_std(listener).unwrap()); let wroutes = build_warp_routes(address, &extension_routes, msg_send.clone()); let fut = warp::serve(wroutes).run_incoming(listener.incoming()); rt.block_on(fut); })?; let builder = thread::Builder::new().name("webdriver dispatcher".to_string()); builder.spawn(move || { let mut dispatcher = Dispatcher::new(handler); dispatcher.run(&msg_recv); })?; Ok(Listener { guard: Some(handle), socket: addr, }) } fn build_warp_routes( address: SocketAddr, ext_routes: &[(Method, &'static str, U)], chan: Sender>, ) -> impl Filter + Clone { let chan = Arc::new(Mutex::new(chan)); let mut std_routes = standard_routes::(); let (method, path, res) = std_routes.pop().unwrap(); let mut wroutes = build_route(address, method, path, res, chan.clone()); for (method, path, res) in std_routes { wroutes = wroutes .or(build_route( address, method, path, res.clone(), chan.clone(), )) .unify() .boxed() } for (method, path, res) in ext_routes { wroutes = wroutes .or(build_route( address, method.clone(), path, Route::Extension(res.clone()), chan.clone(), )) .unify() .boxed() } wroutes } fn build_route( address: SocketAddr, method: Method, path: &'static str, route: Route, chan: Arc>>>, ) -> warp::filters::BoxedFilter<(impl warp::Reply,)> { // Create an empty filter based on the provided method and append an empty hashmap to it. The // hashmap will be used to store path parameters. let mut subroute = match method { Method::GET => warp::get().boxed(), Method::POST => warp::post().boxed(), Method::DELETE => warp::delete().boxed(), Method::OPTIONS => warp::options().boxed(), Method::PUT => warp::put().boxed(), _ => panic!("Unsupported method"), } .or(warp::head()) .unify() .map(Parameters::new) .boxed(); // For each part of the path, if it's a normal part, just append it to the current filter, // otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert // it into the hashmap created earlier. for part in path.split('/') { if part.is_empty() { continue; } else if part.starts_with('{') { assert!(part.ends_with('}')); subroute = subroute .and(warp::path::param()) .map(move |mut params: Parameters, param: String| { let name = &part[1..part.len() - 1]; params.insert(name.to_string(), param); params }) .boxed(); } else { subroute = subroute.and(warp::path(part)).boxed(); } } // Finally, tell warp that the path is complete subroute .and(warp::path::end()) .and(warp::path::full()) .and(warp::method()) .and(warp::header::optional::("origin")) .and(warp::header::optional::("content-type")) .and(warp::body::bytes()) .map( move |params, full_path: warp::path::FullPath, method, origin_header: Option, content_type_header: Option, body: Bytes| { if method == Method::HEAD { return warp::reply::with_status("".into(), StatusCode::OK); } if let Some(origin) = origin_header { let mut valid_host = false; let host_url = Url::parse(&origin).ok(); let host = host_url.as_ref().and_then(|x| x.host().to_owned()); if let Some(host) = host { valid_host = match host { Host::Domain("localhost") => true, Host::Domain(_) => false, Host::Ipv4(x) => address.is_ipv4() && x == address.ip(), Host::Ipv6(x) => address.is_ipv6() && x == address.ip(), }; } if !valid_host { let err = WebDriverError::new(ErrorStatus::UnknownError, "Invalid Origin"); return warp::reply::with_status( serde_json::to_string(&err).unwrap(), StatusCode::INTERNAL_SERVER_ERROR, ); } } if method == Method::POST { // Disallow CORS-safelisted request headers // c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header let content_type = content_type_header .as_ref() .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x)) .map(|x| x.trim()) .map(|x| x.to_lowercase()); match content_type.as_ref().map(|x| x.as_ref()) { Some("application/x-www-form-urlencoded") | Some("multipart/form-data") | Some("text/plain") => { let err = WebDriverError::new( ErrorStatus::UnknownError, "Invalid Content-Type", ); return warp::reply::with_status( serde_json::to_string(&err).unwrap(), StatusCode::INTERNAL_SERVER_ERROR, ); } Some(_) | None => {} } } let body = String::from_utf8(body.bytes().to_vec()); if body.is_err() { let err = WebDriverError::new( ErrorStatus::UnknownError, "Request body wasn't valid UTF-8", ); return warp::reply::with_status( serde_json::to_string(&err).unwrap(), StatusCode::INTERNAL_SERVER_ERROR, ); } let body = body.unwrap(); debug!("-> {} {} {}", method, full_path.as_str(), body); let msg_result = WebDriverMessage::from_http( route.clone(), ¶ms, &body, method == Method::POST, ); let (status, resp_body) = match msg_result { Ok(message) => { let (send_res, recv_res) = channel(); match chan.lock() { Ok(ref c) => { let res = c.send(DispatchMessage::HandleWebDriver(message, send_res)); match res { Ok(x) => x, Err(e) => panic!("Error: {:?}", e), } } Err(e) => panic!("Error reading response: {:?}", e), } match recv_res.recv() { Ok(data) => match data { Ok(response) => { (StatusCode::OK, serde_json::to_string(&response).unwrap()) } Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()), }, Err(e) => panic!("Error reading response: {:?}", e), } } Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()), }; debug!("<- {} {}", status, resp_body); warp::reply::with_status(resp_body, status) }, ) .with(warp::reply::with::header( http::header::CONTENT_TYPE, "application/json; charset=utf-8", )) .with(warp::reply::with::header( http::header::CACHE_CONTROL, "no-cache", )) .boxed() }