From 5dcb558f5000077a51aa50ffabae83a1ced21113 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 27 Feb 2018 10:09:24 -0800 Subject: [PATCH] refactor websockets handling --- examples/websocket-chat/src/main.rs | 5 +- examples/websocket/src/client.rs | 4 +- examples/websocket/src/main.rs | 5 +- guide/src/qs_9.md | 5 +- src/error.rs | 94 +--------------- src/ws/client.rs | 126 ++++++++++++--------- src/ws/context.rs | 2 +- src/ws/frame.rs | 48 +++++--- src/ws/mod.rs | 165 ++++++++++++++++++++++++---- tests/test_ws.rs | 3 +- 10 files changed, 265 insertions(+), 192 deletions(-) diff --git a/examples/websocket-chat/src/main.rs b/examples/websocket-chat/src/main.rs index 50e6bff5..b6783e83 100644 --- a/examples/websocket-chat/src/main.rs +++ b/examples/websocket-chat/src/main.rs @@ -92,8 +92,7 @@ impl Handler for WsChatSession { } /// WebSocket message handler -impl Handler for WsChatSession { - type Result = (); +impl StreamHandler for WsChatSession { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { println!("WEBSOCKET MESSAGE: {:?}", msg); @@ -161,7 +160,7 @@ impl Handler for WsChatSession { }, ws::Message::Binary(bin) => println!("Unexpected binary"), - ws::Message::Close(_) | ws::Message::Error => { + ws::Message::Close(_) => { ctx.stop(); } } diff --git a/examples/websocket/src/client.rs b/examples/websocket/src/client.rs index 8dc410a7..dddc53b7 100644 --- a/examples/websocket/src/client.rs +++ b/examples/websocket/src/client.rs @@ -12,7 +12,7 @@ use std::time::Duration; use actix::*; use futures::Future; -use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; +use actix_web::ws::{Message, WsError, WsClient, WsClientWriter}; fn main() { @@ -93,7 +93,7 @@ impl Handler for ChatClient { } /// Handle server websocket messages -impl StreamHandler for ChatClient { +impl StreamHandler for ChatClient { fn handle(&mut self, msg: Message, ctx: &mut Context) { match msg { diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index 620e56d4..7e082454 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -25,8 +25,7 @@ impl Actor for MyWebSocket { } /// Handler for `ws::Message` -impl Handler for MyWebSocket { - type Result = (); +impl StreamHandler for MyWebSocket { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { // process websocket messages @@ -35,7 +34,7 @@ impl Handler for MyWebSocket { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(_) | ws::Message::Error => { + ws::Message::Close(_) => { ctx.stop(); } _ => (), diff --git a/guide/src/qs_9.md b/guide/src/qs_9.md index 70d1e018..8200435e 100644 --- a/guide/src/qs_9.md +++ b/guide/src/qs_9.md @@ -21,9 +21,8 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -/// Define Handler for ws::Message message -impl Handler for Ws { - type Result=(); +/// Handler for ws::Message message +impl StreamHandler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg { diff --git a/src/error.rs b/src/error.rs index 4cc67486..6c50db25 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,7 +24,7 @@ use body::Body; use handler::Responder; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use httpcodes::{self, HTTPBadRequest, HTTPMethodNotAllowed, HTTPExpectationFailed}; +use httpcodes::{self, HTTPExpectationFailed}; /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// for actix web operations @@ -341,82 +341,6 @@ impl ResponseError for ExpectError { } } -/// Websocket handshake errors -#[derive(Fail, PartialEq, Debug)] -pub enum WsHandshakeError { - /// Only get method is allowed - #[fail(display="Method not allowed")] - GetMethodRequired, - /// Upgrade header if not set to websocket - #[fail(display="Websocket upgrade is expected")] - NoWebsocketUpgrade, - /// Connection header is not set to upgrade - #[fail(display="Connection upgrade is expected")] - NoConnectionUpgrade, - /// Websocket version header is not set - #[fail(display="Websocket version header is required")] - NoVersionHeader, - /// Unsupported websocket version - #[fail(display="Unsupported version")] - UnsupportedVersion, - /// Websocket key is not set or wrong - #[fail(display="Unknown websocket key")] - BadWebsocketKey, -} - -impl ResponseError for WsHandshakeError { - - fn error_response(&self) -> HttpResponse { - match *self { - WsHandshakeError::GetMethodRequired => { - HTTPMethodNotAllowed - .build() - .header(header::ALLOW, "GET") - .finish() - .unwrap() - } - WsHandshakeError::NoWebsocketUpgrade => - HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"), - WsHandshakeError::NoConnectionUpgrade => - HTTPBadRequest.with_reason("No CONNECTION upgrade"), - WsHandshakeError::NoVersionHeader => - HTTPBadRequest.with_reason("Websocket version header is required"), - WsHandshakeError::UnsupportedVersion => - HTTPBadRequest.with_reason("Unsupported version"), - WsHandshakeError::BadWebsocketKey => - HTTPBadRequest.with_reason("Handshake error"), - } - } -} - -/// Websocket errors -#[derive(Fail, Debug)] -pub enum WsError { - /// Received an unmasked frame from client - #[fail(display="Received an unmasked frame from client")] - UnmaskedFrame, - /// Received a masked frame from server - #[fail(display="Received a masked frame from server")] - MaskedFrame, - /// Encountered invalid opcode - #[fail(display="Invalid opcode: {}", _0)] - InvalidOpcode(u8), - /// Invalid control frame length - #[fail(display="Invalid control frame length: {}", _0)] - InvalidLength(usize), - /// Payload error - #[fail(display="Payload error: {}", _0)] - Payload(#[cause] PayloadError), -} - -impl ResponseError for WsError {} - -impl From for WsError { - fn from(err: PayloadError) -> WsError { - WsError::Payload(err) - } -} - /// A set of errors that can occur during parsing urlencoded payloads #[derive(Fail, Debug)] pub enum UrlencodedError { @@ -769,22 +693,6 @@ mod tests { assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED); } - #[test] - fn test_wserror_http_response() { - let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } - macro_rules! from { ($from:expr => $error:pat) => { match ParseError::from($from) { diff --git a/src/ws/client.rs b/src/ws/client.rs index 19e2543c..7b7a0741 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -17,14 +17,14 @@ use byteorder::{ByteOrder, NetworkEndian}; use actix::prelude::*; use body::{Body, Binary}; -use error::{WsError, UrlParseError}; +use error::UrlParseError; use payload::PayloadHelper; use client::{ClientRequest, ClientRequestBuilder, ClientResponse, ClientConnector, SendRequest, SendRequestError, HttpResponseParserError}; -use super::Message; +use super::{Message, WsError}; use super::frame::Frame; use super::proto::{CloseCode, OpCode}; @@ -106,6 +106,7 @@ pub struct WsClient { origin: Option, protocols: Option, conn: Addr, + max_size: usize, } impl WsClient { @@ -123,6 +124,7 @@ impl WsClient { http_err: None, origin: None, protocols: None, + max_size: 65_536, conn, }; cl.request.uri(uri.as_ref()); @@ -158,6 +160,14 @@ impl WsClient { self } + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + /// Set request header pub fn header(mut self, key: K, value: V) -> Self where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom @@ -167,12 +177,12 @@ impl WsClient { } /// Connect to websocket server and do ws handshake - pub fn connect(&mut self) -> WsHandshake { + pub fn connect(&mut self) -> WsClientHandshake { if let Some(e) = self.err.take() { - WsHandshake::new(None, Some(e), &self.conn) + WsClientHandshake::error(e) } else if let Some(e) = self.http_err.take() { - WsHandshake::new(None, Some(e.into()), &self.conn) + WsClientHandshake::error(e.into()) } else { // origin if let Some(origin) = self.origin.take() { @@ -189,23 +199,22 @@ impl WsClient { } let request = match self.request.finish() { Ok(req) => req, - Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn), + Err(err) => return WsClientHandshake::error(err.into()), }; if request.uri().host().is_none() { - return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) + return WsClientHandshake::error(WsClientError::InvalidUrl) } if let Some(scheme) = request.uri().scheme_part() { if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { - return WsHandshake::new( - None, Some(WsClientError::InvalidUrl), &self.conn) + return WsClientHandshake::error(WsClientError::InvalidUrl) } } else { - return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) + return WsClientHandshake::error(WsClientError::InvalidUrl) } // start handshake - WsHandshake::new(Some(request), None, &self.conn) + WsClientHandshake::new(request, &self.conn, self.max_size) } } } @@ -216,17 +225,17 @@ struct WsInner { closed: bool, } -pub struct WsHandshake { +pub struct WsClientHandshake { request: Option, tx: Option>, key: String, error: Option, + max_size: usize, } -impl WsHandshake { - fn new(request: Option, - err: Option, - conn: &Addr) -> WsHandshake +impl WsClientHandshake { + fn new(mut request: ClientRequest, + conn: &Addr, max_size: usize) -> WsClientHandshake { // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, @@ -234,34 +243,36 @@ impl WsHandshake { let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); - if let Some(mut request) = request { - request.headers_mut().insert( - HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), - HeaderValue::try_from(key.as_str()).unwrap()); + request.headers_mut().insert( + HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), + HeaderValue::try_from(key.as_str()).unwrap()); - let (tx, rx) = unbounded(); - request.set_body(Body::Streaming( - Box::new(rx.map_err(|_| io::Error::new( - io::ErrorKind::Other, "disconnected").into())))); + let (tx, rx) = unbounded(); + request.set_body(Body::Streaming( + Box::new(rx.map_err(|_| io::Error::new( + io::ErrorKind::Other, "disconnected").into())))); - WsHandshake { - key, - request: Some(request.with_connector(conn.clone())), - tx: Some(tx), - error: err, - } - } else { - WsHandshake { - key, - request: None, - tx: None, - error: err, - } + WsClientHandshake { + key, + max_size, + request: Some(request.with_connector(conn.clone())), + tx: Some(tx), + error: None, + } + } + + fn error(err: WsClientError) -> WsClientHandshake { + WsClientHandshake { + key: String::new(), + request: None, + tx: None, + error: Some(err), + max_size: 0 } } } -impl Future for WsHandshake { +impl Future for WsClientHandshake { type Item = (WsClientReader, WsClientWriter); type Error = WsClientError; @@ -330,13 +341,15 @@ impl Future for WsHandshake { let inner = Rc::new(UnsafeCell::new(inner)); Ok(Async::Ready( - (WsClientReader{inner: Rc::clone(&inner)}, WsClientWriter{inner}))) + (WsClientReader{inner: Rc::clone(&inner), max_size: self.max_size}, + WsClientWriter{inner}))) } } pub struct WsClientReader { - inner: Rc> + inner: Rc>, + max_size: usize, } impl fmt::Debug for WsClientReader { @@ -354,29 +367,36 @@ impl WsClientReader { impl Stream for WsClientReader { type Item = Message; - type Error = WsClientError; + type Error = WsError; fn poll(&mut self) -> Poll, Self::Error> { + let max_size = self.max_size; let inner = self.as_mut(); if inner.closed { return Ok(Async::Ready(None)) } // read - match Frame::parse(&mut inner.rx, false) { + match Frame::parse(&mut inner.rx, false, max_size) { Ok(Async::Ready(Some(frame))) => { - // trace!("WsFrame {}", frame); - let (_finished, opcode, payload) = frame.unpack(); + let (finished, opcode, payload) = frame.unpack(); + + // continuation is not supported + if !finished { + inner.closed = true; + return Err(WsError::NoContinuation) + } match opcode { OpCode::Continue => unimplemented!(), - OpCode::Bad => - Ok(Async::Ready(Some(Message::Error))), + OpCode::Bad => { + inner.closed = true; + Err(WsError::BadOpCode) + }, OpCode::Close => { inner.closed = true; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; - Ok(Async::Ready( - Some(Message::Close(CloseCode::from(code))))) + Ok(Async::Ready(Some(Message::Close(CloseCode::from(code))))) }, OpCode::Ping => Ok(Async::Ready(Some( @@ -393,17 +413,19 @@ impl Stream for WsClientReader { match String::from_utf8(tmp) { Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => - Ok(Async::Ready(Some(Message::Error))), + Err(_) => { + inner.closed = true; + Err(WsError::BadEncoding) + } } } } } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { + Err(e) => { inner.closed = true; - Err(err.into()) + Err(e) } } } diff --git a/src/ws/context.rs b/src/ws/context.rs index b9214b74..8720b461 100644 --- a/src/ws/context.rs +++ b/src/ws/context.rs @@ -18,7 +18,7 @@ use ws::frame::Frame; use ws::proto::{OpCode, CloseCode}; -/// Http actor execution context +/// `WebSockets` actor execution context pub struct WebsocketContext where A: Actor>, { inner: ContextImpl, diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 7c573b71..32056658 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -6,8 +6,10 @@ use futures::{Async, Poll, Stream}; use rand; use body::Binary; -use error::{WsError, PayloadError}; +use error::{PayloadError}; use payload::PayloadHelper; + +use ws::WsError; use ws::proto::{OpCode, CloseCode}; use ws::mask::apply_mask; @@ -50,7 +52,8 @@ impl Frame { } /// Parse the input stream into a frame. - pub fn parse(pl: &mut PayloadHelper, server: bool) -> Poll, WsError> + pub fn parse(pl: &mut PayloadHelper, server: bool, max_size: usize) + -> Poll, WsError> where S: Stream { let mut idx = 2; @@ -99,6 +102,11 @@ impl Frame { len as usize }; + // check for max allowed size + if length > max_size { + return Err(WsError::Overflow) + } + let mask = if server { let buf = match pl.copy(idx + 4)? { Async::Ready(Some(buf)) => buf, @@ -267,13 +275,13 @@ mod tests { fn test_parse() { let mut buf = PayloadHelper::new( once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze()))); - assert!(is_none(Frame::parse(&mut buf, false))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); buf.extend(b"1"); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = extract(Frame::parse(&mut buf, false)); + let frame = extract(Frame::parse(&mut buf, false, 1024)); println!("FRAME: {}", frame); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); @@ -285,7 +293,7 @@ mod tests { let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = extract(Frame::parse(&mut buf, false)); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert!(frame.payload.is_empty()); @@ -295,14 +303,14 @@ mod tests { fn test_parse_length2() { let buf = BytesMut::from(&[0b00000001u8, 126u8][..]); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(is_none(Frame::parse(&mut buf, false))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]); buf.extend(&[0u8, 4u8][..]); buf.extend(b"1234"); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = extract(Frame::parse(&mut buf, false)); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -312,14 +320,14 @@ mod tests { fn test_parse_length4() { let buf = BytesMut::from(&[0b00000001u8, 127u8][..]); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(is_none(Frame::parse(&mut buf, false))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(b"1234"); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = extract(Frame::parse(&mut buf, false)); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -332,9 +340,9 @@ mod tests { buf.extend(b"1"); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(Frame::parse(&mut buf, false).is_err()); + assert!(Frame::parse(&mut buf, false, 1024).is_err()); - let frame = extract(Frame::parse(&mut buf, true)); + let frame = extract(Frame::parse(&mut buf, true, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); @@ -346,14 +354,28 @@ mod tests { buf.extend(&[1u8]); let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(Frame::parse(&mut buf, true).is_err()); + assert!(Frame::parse(&mut buf, true, 1024).is_err()); - let frame = extract(Frame::parse(&mut buf, false)); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); } + #[test] + fn test_parse_frame_max_size() { + let mut buf = BytesMut::from(&[0b00000001u8, 0b00000010u8][..]); + buf.extend(&[1u8, 1u8]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + assert!(Frame::parse(&mut buf, true, 1).is_err()); + + if let Err(WsError::Overflow) = Frame::parse(&mut buf, false, 0) { + } else { + panic!("error"); + } + } + #[test] fn test_ping_frame() { let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); diff --git a/src/ws/mod.rs b/src/ws/mod.rs index ef9eaf32..b2f0da3c 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -23,9 +23,8 @@ //! type Context = ws::WebsocketContext; //! } //! -//! // Define Handler for ws::Message message -//! impl Handler for Ws { -//! type Result = (); +//! // Handler for ws::Message messages +//! impl StreamHandler for Ws { //! //! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { //! match msg { @@ -48,13 +47,14 @@ use http::{Method, StatusCode, header}; use futures::{Async, Poll, Stream}; use byteorder::{ByteOrder, NetworkEndian}; -use actix::{Actor, AsyncContext, Handler}; +use actix::{Actor, AsyncContext, StreamHandler}; use body::Binary; use payload::PayloadHelper; -use error::{Error, WsHandshakeError, PayloadError}; +use error::{Error, PayloadError, ResponseError}; use httprequest::HttpRequest; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; +use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed}; mod frame; mod proto; @@ -66,7 +66,8 @@ use self::frame::Frame; use self::proto::{hash_key, OpCode}; pub use self::proto::CloseCode; pub use self::context::WebsocketContext; -pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake}; +pub use self::client::{WsClient, WsClientError, + WsClientReader, WsClientWriter, WsClientHandshake}; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; @@ -74,6 +75,94 @@ const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION"; // const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL"; +/// Websocket errors +#[derive(Fail, Debug)] +pub enum WsError { + /// Received an unmasked frame from client + #[fail(display="Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[fail(display="Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[fail(display="Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[fail(display="Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Bad web socket op code + #[fail(display="Bad web socket op code")] + BadOpCode, + /// A payload reached size limit. + #[fail(display="A payload reached size limit.")] + Overflow, + /// Continuation is not supproted + #[fail(display="Continuation is not supproted.")] + NoContinuation, + /// Bad utf-8 encoding + #[fail(display="Bad utf-8 encoding.")] + BadEncoding, + /// Payload error + #[fail(display="Payload error: {}", _0)] + Payload(#[cause] PayloadError), +} + +impl ResponseError for WsError {} + +impl From for WsError { + fn from(err: PayloadError) -> WsError { + WsError::Payload(err) + } +} + +/// Websocket handshake errors +#[derive(Fail, PartialEq, Debug)] +pub enum WsHandshakeError { + /// Only get method is allowed + #[fail(display="Method not allowed")] + GetMethodRequired, + /// Upgrade header if not set to websocket + #[fail(display="Websocket upgrade is expected")] + NoWebsocketUpgrade, + /// Connection header is not set to upgrade + #[fail(display="Connection upgrade is expected")] + NoConnectionUpgrade, + /// Websocket version header is not set + #[fail(display="Websocket version header is required")] + NoVersionHeader, + /// Unsupported websocket version + #[fail(display="Unsupported version")] + UnsupportedVersion, + /// Websocket key is not set or wrong + #[fail(display="Unknown websocket key")] + BadWebsocketKey, +} + +impl ResponseError for WsHandshakeError { + + fn error_response(&self) -> HttpResponse { + match *self { + WsHandshakeError::GetMethodRequired => { + HTTPMethodNotAllowed + .build() + .header(header::ALLOW, "GET") + .finish() + .unwrap() + } + WsHandshakeError::NoWebsocketUpgrade => + HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"), + WsHandshakeError::NoConnectionUpgrade => + HTTPBadRequest.with_reason("No CONNECTION upgrade"), + WsHandshakeError::NoVersionHeader => + HTTPBadRequest.with_reason("Websocket version header is required"), + WsHandshakeError::UnsupportedVersion => + HTTPBadRequest.with_reason("Unsupported version"), + WsHandshakeError::BadWebsocketKey => + HTTPBadRequest.with_reason("Handshake error"), + } + } +} + /// `WebSocket` Message #[derive(Debug, PartialEq, Message)] pub enum Message { @@ -82,19 +171,18 @@ pub enum Message { Ping(String), Pong(String), Close(CloseCode), - Error } /// Do websocket handshake and start actor pub fn start(req: HttpRequest, actor: A) -> Result - where A: Actor> + Handler, + where A: Actor> + StreamHandler, S: 'static { let mut resp = handshake(&req)?; let stream = WsStream::new(req.clone()); let mut ctx = WebsocketContext::new(req, actor); - ctx.add_message_stream(stream); + ctx.add_stream(stream); Ok(resp.body(ctx)?) } @@ -168,33 +256,52 @@ pub fn handshake(req: &HttpRequest) -> Result { rx: PayloadHelper, closed: bool, + max_size: usize, } impl WsStream where S: Stream { + /// Create new websocket frames stream pub fn new(stream: S) -> WsStream { WsStream { rx: PayloadHelper::new(stream), - closed: false } + closed: false, + max_size: 65_536, + } + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self } } impl Stream for WsStream where S: Stream { type Item = Message; - type Error = (); + type Error = WsError; fn poll(&mut self) -> Poll, Self::Error> { if self.closed { return Ok(Async::Ready(None)) } - match Frame::parse(&mut self.rx, true) { + match Frame::parse(&mut self.rx, true, self.max_size) { Ok(Async::Ready(Some(frame))) => { - // trace!("WsFrame {}", frame); - let (_finished, opcode, payload) = frame.unpack(); + let (finished, opcode, payload) = frame.unpack(); + + // continuation is not supported + if !finished { + self.closed = true; + return Err(WsError::NoContinuation) + } match opcode { OpCode::Continue => unimplemented!(), - OpCode::Bad => - Ok(Async::Ready(Some(Message::Error))), + OpCode::Bad => { + self.closed = true; + Err(WsError::BadOpCode) + } OpCode::Close => { self.closed = true; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; @@ -215,17 +322,19 @@ impl Stream for WsStream where S: Stream { match String::from_utf8(tmp) { Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => - Ok(Async::Ready(Some(Message::Error))), + Err(_) => { + self.closed = true; + Err(WsError::BadEncoding) + } } } } } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => { + Err(e) => { self.closed = true; - Ok(Async::Ready(Some(Message::Error))) + Err(e) } } } @@ -306,4 +415,20 @@ mod tests { assert_eq!(StatusCode::SWITCHING_PROTOCOLS, handshake(&req).unwrap().finish().unwrap().status()); } + + #[test] + fn test_wserror_http_response() { + let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } } diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 8b7d0111..13aeef48 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -16,8 +16,7 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -impl Handler for Ws { - type Result = (); +impl StreamHandler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg {