diff --git a/actix-framed/src/lib.rs b/actix-framed/src/lib.rs index a67b82e43..62b0cf1ec 100644 --- a/actix-framed/src/lib.rs +++ b/actix-framed/src/lib.rs @@ -2,6 +2,7 @@ mod app; mod helpers; mod request; mod route; +mod service; mod state; // re-export for convinience @@ -10,4 +11,5 @@ pub use actix_http::{http, Error, HttpMessage, Response, ResponseError}; pub use self::app::{FramedApp, FramedAppService}; pub use self::request::FramedRequest; pub use self::route::FramedRoute; +pub use self::service::{SendError, VerifyWebSockets}; pub use self::state::State; diff --git a/actix-framed/src/service.rs b/actix-framed/src/service.rs new file mode 100644 index 000000000..bc730074c --- /dev/null +++ b/actix-framed/src/service.rs @@ -0,0 +1,112 @@ +use std::marker::PhantomData; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::error::{Error, ResponseError}; +use actix_http::ws::{verify_handshake, HandshakeError}; +use actix_http::{h1, Request}; +use actix_service::{NewService, Service}; +use futures::future::{ok, Either, FutureResult}; +use futures::{Async, Future, IntoFuture, Poll}; + +/// Service that verifies incoming request if it is valid websocket +/// upgrade request. In case of error returns `HandshakeError` +pub struct VerifyWebSockets { + _t: PhantomData, +} + +impl Default for VerifyWebSockets { + fn default() -> Self { + VerifyWebSockets { _t: PhantomData } + } +} + +impl NewService for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type InitError = (); + type Service = VerifyWebSockets; + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(VerifyWebSockets { _t: PhantomData }) + } +} + +impl Service for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { + match verify_handshake(req.head()) { + Err(e) => Err((e, framed)).into_future(), + Ok(_) => Ok((req, framed)).into_future(), + } + } +} + +/// Send http/1 error response +pub struct SendError(PhantomData<(T, R, E)>); + +impl Default for SendError +where + T: AsyncRead + AsyncWrite, + E: ResponseError, +{ + fn default() -> Self { + SendError(PhantomData) + } +} + +impl NewService for SendError +where + T: AsyncRead + AsyncWrite + 'static, + R: 'static, + E: ResponseError + 'static, +{ + type Request = Result)>; + type Response = R; + type Error = Error; + type InitError = (); + type Service = SendError; + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(SendError(PhantomData)) + } +} + +impl Service for SendError +where + T: AsyncRead + AsyncWrite + 'static, + R: 'static, + E: ResponseError + 'static, +{ + type Request = Result)>; + type Response = R; + type Error = Error; + type Future = Either, Box>>; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Result)>) -> Self::Future { + match req { + Ok(r) => Either::A(ok(r)), + Err((e, framed)) => { + let res = e.render_response(); + let e = Error::from(e); + Either::B(Box::new( + h1::SendResponse::new(framed, res).then(move |_| Err(e)), + )) + } + } + } +} diff --git a/actix-framed/tests/test_server.rs b/actix-framed/tests/test_server.rs index 09d2c03cc..5b85f3128 100644 --- a/actix-framed/tests/test_server.rs +++ b/actix-framed/tests/test_server.rs @@ -12,7 +12,7 @@ fn ws_service( req: FramedRequest, ) -> impl Future { let (req, framed, _) = req.into_parts(); - let res = ws::handshake(&req).unwrap().message_body(()); + let res = ws::handshake(req.head()).unwrap().message_body(()); framed .send((res, body::BodySize::None).into()) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index f5988afa4..ae9bb81c4 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -12,6 +12,8 @@ * MessageBody::length() renamed to MessageBody::size() for consistency +* ws handshake verification functions take RequestHead instead of Request + ## [0.1.0-alpha.4] - 2019-04-08 diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index a9ab320ea..b1aefdb1d 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -46,7 +46,7 @@ secure-cookies = ["ring"] [dependencies] actix-service = "0.3.6" actix-codec = "0.1.2" -actix-connect = "0.1.2" +actix-connect = "0.1.3" actix-utils = "0.3.5" actix-server-config = "0.1.0" actix-threadpool = "0.1.0" diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index fdc4cf0bc..c7b7ccec3 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -34,23 +34,35 @@ where B: MessageBody, { type Item = Framed; - type Error = Error; + type Error = (Error, Framed); fn poll(&mut self) -> Poll { loop { let mut body_ready = self.body.is_some(); - let framed = self.framed.as_mut().unwrap(); // send body if self.res.is_none() && self.body.is_some() { - while body_ready && self.body.is_some() && !framed.is_write_buf_full() { - match self.body.as_mut().unwrap().poll_next()? { + while body_ready + && self.body.is_some() + && !self.framed.as_ref().unwrap().is_write_buf_full() + { + match self + .body + .as_mut() + .unwrap() + .poll_next() + .map_err(|e| (e, self.framed.take().unwrap()))? + { Async::Ready(item) => { // body is done if item.is_none() { let _ = self.body.take(); } - framed.force_send(Message::Chunk(item))?; + self.framed + .as_mut() + .unwrap() + .force_send(Message::Chunk(item)) + .map_err(|e| (e.into(), self.framed.take().unwrap()))?; } Async::NotReady => body_ready = false, } @@ -58,8 +70,14 @@ where } // flush write buffer - if !framed.is_write_buf_empty() { - match framed.poll_complete()? { + if !self.framed.as_ref().unwrap().is_write_buf_empty() { + match self + .framed + .as_mut() + .unwrap() + .poll_complete() + .map_err(|e| (e.into(), self.framed.take().unwrap()))? + { Async::Ready(_) => { if body_ready { continue; @@ -73,7 +91,11 @@ where // send response if let Some(res) = self.res.take() { - framed.force_send(res)?; + self.framed + .as_mut() + .unwrap() + .force_send(res) + .map_err(|e| (e.into(), self.framed.take().unwrap()))?; continue; } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 065c34d93..891d5110d 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -9,21 +9,18 @@ use derive_more::{Display, From}; use http::{header, Method, StatusCode}; use crate::error::ResponseError; -use crate::httpmessage::HttpMessage; -use crate::request::Request; +use crate::message::RequestHead; use crate::response::{Response, ResponseBuilder}; mod codec; mod frame; mod mask; mod proto; -mod service; mod transport; pub use self::codec::{Codec, Frame, Message}; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; -pub use self::service::VerifyWebSockets; pub use self::transport::Transport; /// Websocket protocol errors @@ -112,7 +109,7 @@ impl ResponseError for HandshakeError { // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. -pub fn handshake(req: &Request) -> Result { +pub fn handshake(req: &RequestHead) -> Result { verify_handshake(req)?; Ok(handshake_response(req)) } @@ -121,9 +118,9 @@ pub fn handshake(req: &Request) -> Result { // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. -pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { +pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { // WebSocket accepts only GET - if *req.method() != Method::GET { + if req.method != Method::GET { return Err(HandshakeError::GetMethodRequired); } @@ -171,7 +168,7 @@ pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { /// Create websocket's handshake response /// /// This function returns handshake `Response`, ready to send to peer. -pub fn handshake_response(req: &Request) -> ResponseBuilder { +pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { let key = { let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); proto::hash_key(key.as_ref()) @@ -195,13 +192,13 @@ mod tests { let req = TestRequest::default().method(Method::POST).finish(); assert_eq!( HandshakeError::GetMethodRequired, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default().finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -209,7 +206,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -220,7 +217,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -235,7 +232,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoVersionHeader, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -254,7 +251,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::UnsupportedVersion, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -273,7 +270,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::BadWebsocketKey, - verify_handshake(&req).err().unwrap() + verify_handshake(req.head()).err().unwrap() ); let req = TestRequest::default() @@ -296,7 +293,7 @@ mod tests { .finish(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, - handshake_response(&req).finish().status() + handshake_response(req.head()).finish().status() ); } diff --git a/actix-http/src/ws/service.rs b/actix-http/src/ws/service.rs deleted file mode 100644 index f3b066053..000000000 --- a/actix-http/src/ws/service.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::marker::PhantomData; - -use actix_codec::Framed; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{Async, IntoFuture, Poll}; - -use crate::h1::Codec; -use crate::request::Request; - -use super::{verify_handshake, HandshakeError}; - -pub struct VerifyWebSockets { - _t: PhantomData, -} - -impl Default for VerifyWebSockets { - fn default() -> Self { - VerifyWebSockets { _t: PhantomData } - } -} - -impl NewService for VerifyWebSockets { - type Request = (Request, Framed); - type Response = (Request, Framed); - type Error = (HandshakeError, Framed); - type InitError = (); - type Service = VerifyWebSockets; - type Future = FutureResult; - - fn new_service(&self, _: &()) -> Self::Future { - ok(VerifyWebSockets { _t: PhantomData }) - } -} - -impl Service for VerifyWebSockets { - type Request = (Request, Framed); - type Response = (Request, Framed); - type Error = (HandshakeError, Framed); - type Future = FutureResult; - - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) - } - - fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { - match verify_handshake(&req) { - Err(e) => Err((e, framed)).into_future(), - Ok(_) => Ok((req, framed)).into_future(), - } - } -} diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index b6be748bd..65a4d094d 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -9,7 +9,7 @@ use futures::{Future, Sink, Stream}; fn ws_service( (req, framed): (Request, Framed), ) -> impl Future { - let res = ws::handshake(&req).unwrap().message_body(()); + let res = ws::handshake(req.head()).unwrap().message_body(()); framed .send((res, body::BodySize::None).into()) diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 04a6a110a..4fc9f4bdc 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -11,7 +11,7 @@ use futures::future::{ok, Either}; use futures::{Future, Sink, Stream}; use tokio_tcp::TcpStream; -use actix_http::{body::BodySize, h1, ws, ResponseError, ServiceConfig}; +use actix_http::{body::BodySize, h1, ws, Request, ResponseError, ServiceConfig}; fn ws_service(req: ws::Frame) -> impl Future { match req { @@ -40,47 +40,49 @@ fn test_simple() { fn_service(|io: Io| Ok(io.into_parts().0)) .and_then(IntoFramed::new(|| h1::Codec::new(ServiceConfig::default()))) .and_then(TakeItem::new().map_err(|_| ())) - .and_then(|(req, framed): (_, Framed<_, _>)| { - // validate request - if let Some(h1::Message::Item(req)) = req { - match ws::verify_handshake(&req) { - Err(e) => { - // validation failed - let res = e.error_response(); - Either::A( - framed - .send(h1::Message::Item(( - res.drop_body(), - BodySize::Empty, - ))) - .map_err(|_| ()) - .map(|_| ()), - ) - } - Ok(_) => { - let res = ws::handshake_response(&req).finish(); - Either::B( - // send handshake response - framed - .send(h1::Message::Item(( - res.drop_body(), - BodySize::None, - ))) - .map_err(|_| ()) - .and_then(|framed| { - // start websocket service - let framed = - framed.into_framed(ws::Codec::new()); - ws::Transport::with(framed, ws_service) - .map_err(|_| ()) - }), - ) + .and_then( + |(req, framed): (Option>, Framed<_, _>)| { + // validate request + if let Some(h1::Message::Item(req)) = req { + match ws::verify_handshake(req.head()) { + Err(e) => { + // validation failed + let res = e.error_response(); + Either::A( + framed + .send(h1::Message::Item(( + res.drop_body(), + BodySize::Empty, + ))) + .map_err(|_| ()) + .map(|_| ()), + ) + } + Ok(_) => { + let res = ws::handshake_response(req.head()).finish(); + Either::B( + // send handshake response + framed + .send(h1::Message::Item(( + res.drop_body(), + BodySize::None, + ))) + .map_err(|_| ()) + .and_then(|framed| { + // start websocket service + let framed = + framed.into_framed(ws::Codec::new()); + ws::Transport::with(framed, ws_service) + .map_err(|_| ()) + }), + ) + } } + } else { + panic!() } - } else { - panic!() - } - }) + }, + ) }); // client service