diff --git a/src/error.rs b/src/error.rs index 465b8ae0..3c090203 100644 --- a/src/error.rs +++ b/src/error.rs @@ -632,6 +632,13 @@ where } } +/// Convert Response to a Error +impl From for Error { + fn from(res: Response) -> Error { + InternalError::from_response("", res).into() + } +} + /// Helper function that creates wrapper of any error and generate *BAD /// REQUEST* response. #[allow(non_snake_case)] diff --git a/src/h1/service.rs b/src/h1/service.rs index 404ded6b..096cb301 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -267,7 +267,10 @@ pub struct OneRequest { _t: PhantomData, } -impl OneRequest { +impl OneRequest +where + T: AsyncRead + AsyncWrite, +{ /// Create new `H1SimpleService` instance. pub fn new() -> Self { OneRequest { diff --git a/src/lib.rs b/src/lib.rs index 5ce3ec39..c9cebb47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,6 +110,7 @@ mod json; mod payload; mod request; mod response; +mod service; pub mod uri; pub mod error; @@ -123,6 +124,7 @@ pub use extensions::Extensions; pub use httpmessage::HttpMessage; pub use request::Request; pub use response::Response; +pub use service::{SendError, SendResponse}; pub use self::config::{KeepAlive, ServiceConfig, ServiceConfigBuilder}; diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 00000000..4467087b --- /dev/null +++ b/src/service.rs @@ -0,0 +1,185 @@ +use std::io; +use std::marker::PhantomData; + +use actix_net::codec::Framed; +use actix_net::service::{NewService, Service}; +use futures::future::{ok, Either, FutureResult}; +use futures::{Async, AsyncSink, Future, Poll, Sink}; +use tokio_io::AsyncWrite; + +use error::ResponseError; +use h1::{Codec, OutMessage}; +use response::Response; + +pub struct SendError(PhantomData<(T, R, E)>); + +impl Default for SendError +where + T: AsyncWrite, + E: ResponseError, +{ + fn default() -> Self { + SendError(PhantomData) + } +} + +impl NewService for SendError +where + T: AsyncWrite, + E: ResponseError, +{ + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type InitError = (); + type Service = SendError; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(SendError(PhantomData)) + } +} + +impl Service for SendError +where + T: AsyncWrite, + E: ResponseError, +{ + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type Future = Either)>, SendErrorFut>; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + match req { + Ok(r) => Either::A(ok(r)), + Err((e, framed)) => Either::B(SendErrorFut { + framed: Some(framed), + res: Some(OutMessage::Response(e.error_response())), + err: Some(e), + _t: PhantomData, + }), + } + } +} + +pub struct SendErrorFut { + res: Option, + framed: Option>, + err: Option, + _t: PhantomData, +} + +impl Future for SendErrorFut +where + E: ResponseError, + T: AsyncWrite, +{ + type Item = R; + type Error = (E, Framed); + + fn poll(&mut self) -> Poll { + if let Some(res) = self.res.take() { + match self.framed.as_mut().unwrap().start_send(res) { + Ok(AsyncSink::Ready) => (), + Ok(AsyncSink::NotReady(res)) => { + self.res = Some(res); + return Ok(Async::NotReady); + } + Err(_) => { + return Err((self.err.take().unwrap(), self.framed.take().unwrap())) + } + } + } + match self.framed.as_mut().unwrap().poll_complete() { + Ok(Async::Ready(_)) => { + return Err((self.err.take().unwrap(), self.framed.take().unwrap())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => { + return Err((self.err.take().unwrap(), self.framed.take().unwrap())) + } + } + } +} + +pub struct SendResponse(PhantomData<(T,)>); + +impl Default for SendResponse +where + T: AsyncWrite, +{ + fn default() -> Self { + SendResponse(PhantomData) + } +} + +impl NewService for SendResponse +where + T: AsyncWrite, +{ + type Request = (Response, Framed); + type Response = Framed; + type Error = io::Error; + type InitError = (); + type Service = SendResponse; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(SendResponse(PhantomData)) + } +} + +impl Service for SendResponse +where + T: AsyncWrite, +{ + type Request = (Response, Framed); + type Response = Framed; + type Error = io::Error; + type Future = SendResponseFut; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, (res, framed): Self::Request) -> Self::Future { + SendResponseFut { + res: Some(OutMessage::Response(res)), + framed: Some(framed), + } + } +} + +pub struct SendResponseFut { + res: Option, + framed: Option>, +} + +impl Future for SendResponseFut +where + T: AsyncWrite, +{ + type Item = Framed; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + if let Some(res) = self.res.take() { + match self.framed.as_mut().unwrap().start_send(res)? { + AsyncSink::Ready => (), + AsyncSink::NotReady(res) => { + self.res = Some(res); + return Ok(Async::NotReady); + } + } + } + match self.framed.as_mut().unwrap().poll_complete()? { + Async::Ready(_) => Ok(Async::Ready(self.framed.take().unwrap())), + Async::NotReady => Ok(Async::NotReady), + } + } +} diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 7df1f4b4..690c56fe 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -14,11 +14,13 @@ mod codec; mod frame; mod mask; mod proto; +mod service; mod transport; pub use self::codec::{Codec, Frame, Message}; pub use self::frame::Parser; pub use self::proto::{CloseCode, CloseReason, OpCode}; +pub use self::service::VerifyWebSockets; pub use self::transport::Transport; /// Websocket protocol errors @@ -109,15 +111,20 @@ impl ResponseError for HandshakeError { } } -/// Prepare `WebSocket` handshake response. -/// -/// This function returns handshake `Response`, ready to send to peer. -/// It does not perform any IO. -/// +/// Verify `WebSocket` handshake request and create handshake reponse. // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. pub fn handshake(req: &Request) -> Result { + verify_handshake(req)?; + Ok(handshake_response(req)) +} + +/// Verify `WebSocket` handshake request. +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { // WebSocket accepts only GET if *req.method() != Method::GET { return Err(HandshakeError::GetMethodRequired); @@ -161,17 +168,24 @@ pub fn handshake(req: &Request) -> Result { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } + Ok(()) +} + +/// Create websocket's handshake response +/// +/// This function returns handshake `Response`, ready to send to peer. +pub fn handshake_response(req: &Request) -> ResponseBuilder { let key = { let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); proto::hash_key(key.as_ref()) }; - Ok(Response::build(StatusCode::SWITCHING_PROTOCOLS) + Response::build(StatusCode::SWITCHING_PROTOCOLS) .connection_type(ConnectionType::Upgrade) .header(header::UPGRADE, "websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) - .take()) + .take() } #[cfg(test)] @@ -185,13 +199,13 @@ mod tests { let req = TestRequest::default().method(Method::POST).finish(); assert_eq!( HandshakeError::GetMethodRequired, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default().finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -199,7 +213,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -209,7 +223,7 @@ mod tests { ).finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -222,7 +236,7 @@ mod tests { ).finish(); assert_eq!( HandshakeError::NoVersionHeader, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -238,7 +252,7 @@ mod tests { ).finish(); assert_eq!( HandshakeError::UnsupportedVersion, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -254,7 +268,7 @@ mod tests { ).finish(); assert_eq!( HandshakeError::BadWebsocketKey, - handshake(&req).err().unwrap() + verify_handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -273,7 +287,7 @@ mod tests { ).finish(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, - handshake(&req).unwrap().finish().status() + handshake_response(&req).finish().status() ); } diff --git a/src/ws/service.rs b/src/ws/service.rs new file mode 100644 index 00000000..9cce4d63 --- /dev/null +++ b/src/ws/service.rs @@ -0,0 +1,52 @@ +use std::marker::PhantomData; + +use actix_net::codec::Framed; +use actix_net::service::{NewService, Service}; +use futures::future::{ok, FutureResult}; +use futures::{Async, IntoFuture, Poll}; + +use h1::Codec; +use request::Request; + +use super::{verify_handshake, HandshakeError}; + +pub struct VerifyWebSockets { + _t: PhantomData, +} + +impl Default for VerifyWebSockets { + fn default() -> Self { + VerifyWebSockets { _t: PhantomData } + } +} + +impl NewService for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type InitError = (); + type Service = VerifyWebSockets; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(VerifyWebSockets { _t: PhantomData }) + } +} + +impl Service for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, (req, framed): Self::Request) -> Self::Future { + match verify_handshake(&req) { + Err(e) => Err((e, framed)).into_future(), + Ok(_) => Ok((req, framed)).into_future(), + } + } +} diff --git a/tests/test_ws.rs b/tests/test_ws.rs index f475cd22..a5503c20 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -52,7 +52,7 @@ fn test_simple() { .and_then(|(req, framed): (_, Framed<_, _>)| { // validate request if let Some(h1::InMessage::Message(req, _)) = req { - match ws::handshake(&req) { + match ws::verify_handshake(&req) { Err(e) => { // validation failed let resp = e.error_response(); @@ -63,11 +63,12 @@ fn test_simple() { .map(|_| ()), ) } - Ok(mut resp) => Either::B( + Ok(_) => Either::B( // send response framed - .send(h1::OutMessage::Response(resp.finish())) - .map_err(|_| ()) + .send(h1::OutMessage::Response( + ws::handshake_response(&req).finish(), + )).map_err(|_| ()) .and_then(|framed| { // start websocket service let framed =