1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 01:32:57 +01:00

refactor websockets handling

This commit is contained in:
Nikolay Kim 2018-02-27 10:09:24 -08:00
parent a344c3a02e
commit 5dcb558f50
10 changed files with 265 additions and 192 deletions

View File

@ -92,8 +92,7 @@ impl Handler<session::Message> for WsChatSession {
} }
/// WebSocket message handler /// WebSocket message handler
impl Handler<ws::Message> for WsChatSession { impl StreamHandler<ws::Message, ws::WsError> for WsChatSession {
type Result = ();
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
println!("WEBSOCKET MESSAGE: {:?}", msg); println!("WEBSOCKET MESSAGE: {:?}", msg);
@ -161,7 +160,7 @@ impl Handler<ws::Message> for WsChatSession {
}, },
ws::Message::Binary(bin) => ws::Message::Binary(bin) =>
println!("Unexpected binary"), println!("Unexpected binary"),
ws::Message::Close(_) | ws::Message::Error => { ws::Message::Close(_) => {
ctx.stop(); ctx.stop();
} }
} }

View File

@ -12,7 +12,7 @@ use std::time::Duration;
use actix::*; use actix::*;
use futures::Future; use futures::Future;
use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; use actix_web::ws::{Message, WsError, WsClient, WsClientWriter};
fn main() { fn main() {
@ -93,7 +93,7 @@ impl Handler<ClientCommand> for ChatClient {
} }
/// Handle server websocket messages /// Handle server websocket messages
impl StreamHandler<Message, WsClientError> for ChatClient { impl StreamHandler<Message, WsError> for ChatClient {
fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) { fn handle(&mut self, msg: Message, ctx: &mut Context<Self>) {
match msg { match msg {

View File

@ -25,8 +25,7 @@ impl Actor for MyWebSocket {
} }
/// Handler for `ws::Message` /// Handler for `ws::Message`
impl Handler<ws::Message> for MyWebSocket { impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
type Result = ();
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
// process websocket messages // process websocket messages
@ -35,7 +34,7 @@ impl Handler<ws::Message> for MyWebSocket {
ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Ping(msg) => ctx.pong(&msg),
ws::Message::Text(text) => ctx.text(text), ws::Message::Text(text) => ctx.text(text),
ws::Message::Binary(bin) => ctx.binary(bin), ws::Message::Binary(bin) => ctx.binary(bin),
ws::Message::Close(_) | ws::Message::Error => { ws::Message::Close(_) => {
ctx.stop(); ctx.stop();
} }
_ => (), _ => (),

View File

@ -21,9 +21,8 @@ impl Actor for Ws {
type Context = ws::WebsocketContext<Self>; type Context = ws::WebsocketContext<Self>;
} }
/// Define Handler for ws::Message message /// Handler for ws::Message message
impl Handler<ws::Message> for Ws { impl StreamHandler<ws::Message, ws::WsError> for Ws {
type Result=();
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {

View File

@ -24,7 +24,7 @@ use body::Body;
use handler::Responder; use handler::Responder;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use httpcodes::{self, HTTPBadRequest, HTTPMethodNotAllowed, HTTPExpectationFailed}; use httpcodes::{self, HTTPExpectationFailed};
/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html)
/// for actix web operations /// for actix web operations
@ -341,82 +341,6 @@ impl ResponseError for ExpectError {
} }
} }
/// Websocket handshake errors
#[derive(Fail, PartialEq, Debug)]
pub enum WsHandshakeError {
/// Only get method is allowed
#[fail(display="Method not allowed")]
GetMethodRequired,
/// Upgrade header if not set to websocket
#[fail(display="Websocket upgrade is expected")]
NoWebsocketUpgrade,
/// Connection header is not set to upgrade
#[fail(display="Connection upgrade is expected")]
NoConnectionUpgrade,
/// Websocket version header is not set
#[fail(display="Websocket version header is required")]
NoVersionHeader,
/// Unsupported websocket version
#[fail(display="Unsupported version")]
UnsupportedVersion,
/// Websocket key is not set or wrong
#[fail(display="Unknown websocket key")]
BadWebsocketKey,
}
impl ResponseError for WsHandshakeError {
fn error_response(&self) -> HttpResponse {
match *self {
WsHandshakeError::GetMethodRequired => {
HTTPMethodNotAllowed
.build()
.header(header::ALLOW, "GET")
.finish()
.unwrap()
}
WsHandshakeError::NoWebsocketUpgrade =>
HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"),
WsHandshakeError::NoConnectionUpgrade =>
HTTPBadRequest.with_reason("No CONNECTION upgrade"),
WsHandshakeError::NoVersionHeader =>
HTTPBadRequest.with_reason("Websocket version header is required"),
WsHandshakeError::UnsupportedVersion =>
HTTPBadRequest.with_reason("Unsupported version"),
WsHandshakeError::BadWebsocketKey =>
HTTPBadRequest.with_reason("Handshake error"),
}
}
}
/// Websocket errors
#[derive(Fail, Debug)]
pub enum WsError {
/// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")]
UnmaskedFrame,
/// Received a masked frame from server
#[fail(display="Received a masked frame from server")]
MaskedFrame,
/// Encountered invalid opcode
#[fail(display="Invalid opcode: {}", _0)]
InvalidOpcode(u8),
/// Invalid control frame length
#[fail(display="Invalid control frame length: {}", _0)]
InvalidLength(usize),
/// Payload error
#[fail(display="Payload error: {}", _0)]
Payload(#[cause] PayloadError),
}
impl ResponseError for WsError {}
impl From<PayloadError> for WsError {
fn from(err: PayloadError) -> WsError {
WsError::Payload(err)
}
}
/// A set of errors that can occur during parsing urlencoded payloads /// A set of errors that can occur during parsing urlencoded payloads
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum UrlencodedError { pub enum UrlencodedError {
@ -769,22 +693,6 @@ mod tests {
assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED); assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED);
} }
#[test]
fn test_wserror_http_response() {
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
macro_rules! from { macro_rules! from {
($from:expr => $error:pat) => { ($from:expr => $error:pat) => {
match ParseError::from($from) { match ParseError::from($from) {

View File

@ -17,14 +17,14 @@ use byteorder::{ByteOrder, NetworkEndian};
use actix::prelude::*; use actix::prelude::*;
use body::{Body, Binary}; use body::{Body, Binary};
use error::{WsError, UrlParseError}; use error::UrlParseError;
use payload::PayloadHelper; use payload::PayloadHelper;
use client::{ClientRequest, ClientRequestBuilder, ClientResponse, use client::{ClientRequest, ClientRequestBuilder, ClientResponse,
ClientConnector, SendRequest, SendRequestError, ClientConnector, SendRequest, SendRequestError,
HttpResponseParserError}; HttpResponseParserError};
use super::Message; use super::{Message, WsError};
use super::frame::Frame; use super::frame::Frame;
use super::proto::{CloseCode, OpCode}; use super::proto::{CloseCode, OpCode};
@ -106,6 +106,7 @@ pub struct WsClient {
origin: Option<HeaderValue>, origin: Option<HeaderValue>,
protocols: Option<String>, protocols: Option<String>,
conn: Addr<Unsync, ClientConnector>, conn: Addr<Unsync, ClientConnector>,
max_size: usize,
} }
impl WsClient { impl WsClient {
@ -123,6 +124,7 @@ impl WsClient {
http_err: None, http_err: None,
origin: None, origin: None,
protocols: None, protocols: None,
max_size: 65_536,
conn, conn,
}; };
cl.request.uri(uri.as_ref()); cl.request.uri(uri.as_ref());
@ -158,6 +160,14 @@ impl WsClient {
self self
} }
/// Set max frame size
///
/// By default max size is set to 64kb
pub fn max_frame_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Set request header /// Set request header
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, HeaderValue: HttpTryFrom<V> where HeaderName: HttpTryFrom<K>, HeaderValue: HttpTryFrom<V>
@ -167,12 +177,12 @@ impl WsClient {
} }
/// Connect to websocket server and do ws handshake /// Connect to websocket server and do ws handshake
pub fn connect(&mut self) -> WsHandshake { pub fn connect(&mut self) -> WsClientHandshake {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
WsHandshake::new(None, Some(e), &self.conn) WsClientHandshake::error(e)
} }
else if let Some(e) = self.http_err.take() { else if let Some(e) = self.http_err.take() {
WsHandshake::new(None, Some(e.into()), &self.conn) WsClientHandshake::error(e.into())
} else { } else {
// origin // origin
if let Some(origin) = self.origin.take() { if let Some(origin) = self.origin.take() {
@ -189,23 +199,22 @@ impl WsClient {
} }
let request = match self.request.finish() { let request = match self.request.finish() {
Ok(req) => req, Ok(req) => req,
Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn), Err(err) => return WsClientHandshake::error(err.into()),
}; };
if request.uri().host().is_none() { if request.uri().host().is_none() {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) return WsClientHandshake::error(WsClientError::InvalidUrl)
} }
if let Some(scheme) = request.uri().scheme_part() { if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return WsHandshake::new( return WsClientHandshake::error(WsClientError::InvalidUrl)
None, Some(WsClientError::InvalidUrl), &self.conn)
} }
} else { } else {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) return WsClientHandshake::error(WsClientError::InvalidUrl)
} }
// start handshake // start handshake
WsHandshake::new(Some(request), None, &self.conn) WsClientHandshake::new(request, &self.conn, self.max_size)
} }
} }
} }
@ -216,17 +225,17 @@ struct WsInner {
closed: bool, closed: bool,
} }
pub struct WsHandshake { pub struct WsClientHandshake {
request: Option<SendRequest>, request: Option<SendRequest>,
tx: Option<UnboundedSender<Bytes>>, tx: Option<UnboundedSender<Bytes>>,
key: String, key: String,
error: Option<WsClientError>, error: Option<WsClientError>,
max_size: usize,
} }
impl WsHandshake { impl WsClientHandshake {
fn new(request: Option<ClientRequest>, fn new(mut request: ClientRequest,
err: Option<WsClientError>, conn: &Addr<Unsync, ClientConnector>, max_size: usize) -> WsClientHandshake
conn: &Addr<Unsync, ClientConnector>) -> WsHandshake
{ {
// Generate a random key for the `Sec-WebSocket-Key` header. // Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that, // a base64-encoded (see Section 4 of [RFC4648]) value that,
@ -234,7 +243,6 @@ impl WsHandshake {
let sec_key: [u8; 16] = rand::random(); let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key); let key = base64::encode(&sec_key);
if let Some(mut request) = request {
request.headers_mut().insert( request.headers_mut().insert(
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
HeaderValue::try_from(key.as_str()).unwrap()); HeaderValue::try_from(key.as_str()).unwrap());
@ -244,24 +252,27 @@ impl WsHandshake {
Box::new(rx.map_err(|_| io::Error::new( Box::new(rx.map_err(|_| io::Error::new(
io::ErrorKind::Other, "disconnected").into())))); io::ErrorKind::Other, "disconnected").into()))));
WsHandshake { WsClientHandshake {
key, key,
max_size,
request: Some(request.with_connector(conn.clone())), request: Some(request.with_connector(conn.clone())),
tx: Some(tx), tx: Some(tx),
error: err, error: None,
} }
} else { }
WsHandshake {
key, fn error(err: WsClientError) -> WsClientHandshake {
WsClientHandshake {
key: String::new(),
request: None, request: None,
tx: None, tx: None,
error: err, error: Some(err),
} max_size: 0
} }
} }
} }
impl Future for WsHandshake { impl Future for WsClientHandshake {
type Item = (WsClientReader, WsClientWriter); type Item = (WsClientReader, WsClientWriter);
type Error = WsClientError; type Error = WsClientError;
@ -330,13 +341,15 @@ impl Future for WsHandshake {
let inner = Rc::new(UnsafeCell::new(inner)); let inner = Rc::new(UnsafeCell::new(inner));
Ok(Async::Ready( Ok(Async::Ready(
(WsClientReader{inner: Rc::clone(&inner)}, WsClientWriter{inner}))) (WsClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
WsClientWriter{inner})))
} }
} }
pub struct WsClientReader { pub struct WsClientReader {
inner: Rc<UnsafeCell<WsInner>> inner: Rc<UnsafeCell<WsInner>>,
max_size: usize,
} }
impl fmt::Debug for WsClientReader { impl fmt::Debug for WsClientReader {
@ -354,29 +367,36 @@ impl WsClientReader {
impl Stream for WsClientReader { impl Stream for WsClientReader {
type Item = Message; type Item = Message;
type Error = WsClientError; type Error = WsError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let max_size = self.max_size;
let inner = self.as_mut(); let inner = self.as_mut();
if inner.closed { if inner.closed {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None))
} }
// read // read
match Frame::parse(&mut inner.rx, false) { match Frame::parse(&mut inner.rx, false, max_size) {
Ok(Async::Ready(Some(frame))) => { Ok(Async::Ready(Some(frame))) => {
// trace!("WsFrame {}", frame); let (finished, opcode, payload) = frame.unpack();
let (_finished, opcode, payload) = frame.unpack();
// continuation is not supported
if !finished {
inner.closed = true;
return Err(WsError::NoContinuation)
}
match opcode { match opcode {
OpCode::Continue => unimplemented!(), OpCode::Continue => unimplemented!(),
OpCode::Bad => OpCode::Bad => {
Ok(Async::Ready(Some(Message::Error))), inner.closed = true;
Err(WsError::BadOpCode)
},
OpCode::Close => { OpCode::Close => {
inner.closed = true; inner.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
Ok(Async::Ready( Ok(Async::Ready(Some(Message::Close(CloseCode::from(code)))))
Some(Message::Close(CloseCode::from(code)))))
}, },
OpCode::Ping => OpCode::Ping =>
Ok(Async::Ready(Some( Ok(Async::Ready(Some(
@ -393,17 +413,19 @@ impl Stream for WsClientReader {
match String::from_utf8(tmp) { match String::from_utf8(tmp) {
Ok(s) => Ok(s) =>
Ok(Async::Ready(Some(Message::Text(s)))), Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => Err(_) => {
Ok(Async::Ready(Some(Message::Error))), inner.closed = true;
Err(WsError::BadEncoding)
}
} }
} }
} }
} }
Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => { Err(e) => {
inner.closed = true; inner.closed = true;
Err(err.into()) Err(e)
} }
} }
} }

View File

@ -18,7 +18,7 @@ use ws::frame::Frame;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{OpCode, CloseCode};
/// Http actor execution context /// `WebSockets` actor execution context
pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A, S>>, pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A, S>>,
{ {
inner: ContextImpl<A>, inner: ContextImpl<A>,

View File

@ -6,8 +6,10 @@ use futures::{Async, Poll, Stream};
use rand; use rand;
use body::Binary; use body::Binary;
use error::{WsError, PayloadError}; use error::{PayloadError};
use payload::PayloadHelper; use payload::PayloadHelper;
use ws::WsError;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask; use ws::mask::apply_mask;
@ -50,7 +52,8 @@ impl Frame {
} }
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool) -> Poll<Option<Frame>, WsError> pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
-> Poll<Option<Frame>, WsError>
where S: Stream<Item=Bytes, Error=PayloadError> where S: Stream<Item=Bytes, Error=PayloadError>
{ {
let mut idx = 2; let mut idx = 2;
@ -99,6 +102,11 @@ impl Frame {
len as usize len as usize
}; };
// check for max allowed size
if length > max_size {
return Err(WsError::Overflow)
}
let mask = if server { let mask = if server {
let buf = match pl.copy(idx + 4)? { let buf = match pl.copy(idx + 4)? {
Async::Ready(Some(buf)) => buf, Async::Ready(Some(buf)) => buf,
@ -267,13 +275,13 @@ mod tests {
fn test_parse() { fn test_parse() {
let mut buf = PayloadHelper::new( let mut buf = PayloadHelper::new(
once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze()))); once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze())));
assert!(is_none(Frame::parse(&mut buf, false))); assert!(is_none(Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]);
buf.extend(b"1"); buf.extend(b"1");
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false)); let frame = extract(Frame::parse(&mut buf, false, 1024));
println!("FRAME: {}", frame); println!("FRAME: {}", frame);
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
@ -285,7 +293,7 @@ mod tests {
let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]);
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert!(frame.payload.is_empty()); assert!(frame.payload.is_empty());
@ -295,14 +303,14 @@ mod tests {
fn test_parse_length2() { fn test_parse_length2() {
let buf = BytesMut::from(&[0b00000001u8, 126u8][..]); let buf = BytesMut::from(&[0b00000001u8, 126u8][..]);
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
assert!(is_none(Frame::parse(&mut buf, false))); assert!(is_none(Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]); let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]);
buf.extend(&[0u8, 4u8][..]); buf.extend(&[0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
@ -312,14 +320,14 @@ mod tests {
fn test_parse_length4() { fn test_parse_length4() {
let buf = BytesMut::from(&[0b00000001u8, 127u8][..]); let buf = BytesMut::from(&[0b00000001u8, 127u8][..]);
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
assert!(is_none(Frame::parse(&mut buf, false))); assert!(is_none(Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]); let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]);
buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
let frame = extract(Frame::parse(&mut buf, false)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
@ -332,9 +340,9 @@ mod tests {
buf.extend(b"1"); buf.extend(b"1");
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, false).is_err()); assert!(Frame::parse(&mut buf, false, 1024).is_err());
let frame = extract(Frame::parse(&mut buf, true)); let frame = extract(Frame::parse(&mut buf, true, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, vec![1u8].into()); assert_eq!(frame.payload, vec![1u8].into());
@ -346,14 +354,28 @@ mod tests {
buf.extend(&[1u8]); buf.extend(&[1u8]);
let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, true).is_err()); assert!(Frame::parse(&mut buf, true, 1024).is_err());
let frame = extract(Frame::parse(&mut buf, false)); let frame = extract(Frame::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, vec![1u8].into()); assert_eq!(frame.payload, vec![1u8].into());
} }
#[test]
fn test_parse_frame_max_size() {
let mut buf = BytesMut::from(&[0b00000001u8, 0b00000010u8][..]);
buf.extend(&[1u8, 1u8]);
let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
assert!(Frame::parse(&mut buf, true, 1).is_err());
if let Err(WsError::Overflow) = Frame::parse(&mut buf, false, 0) {
} else {
panic!("error");
}
}
#[test] #[test]
fn test_ping_frame() { fn test_ping_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false);

View File

@ -23,9 +23,8 @@
//! type Context = ws::WebsocketContext<Self>; //! type Context = ws::WebsocketContext<Self>;
//! } //! }
//! //!
//! // Define Handler for ws::Message message //! // Handler for ws::Message messages
//! impl Handler<ws::Message> for Ws { //! impl StreamHandler<ws::Message, ws::WsError> for Ws {
//! type Result = ();
//! //!
//! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { //! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
//! match msg { //! match msg {
@ -48,13 +47,14 @@ use http::{Method, StatusCode, header};
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, NetworkEndian};
use actix::{Actor, AsyncContext, Handler}; use actix::{Actor, AsyncContext, StreamHandler};
use body::Binary; use body::Binary;
use payload::PayloadHelper; use payload::PayloadHelper;
use error::{Error, WsHandshakeError, PayloadError}; use error::{Error, PayloadError, ResponseError};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder};
use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed};
mod frame; mod frame;
mod proto; mod proto;
@ -66,7 +66,8 @@ use self::frame::Frame;
use self::proto::{hash_key, OpCode}; use self::proto::{hash_key, OpCode};
pub use self::proto::CloseCode; pub use self::proto::CloseCode;
pub use self::context::WebsocketContext; pub use self::context::WebsocketContext;
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake}; pub use self::client::{WsClient, WsClientError,
WsClientReader, WsClientWriter, WsClientHandshake};
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";
@ -74,6 +75,94 @@ const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION";
// const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL"; // const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL";
/// Websocket errors
#[derive(Fail, Debug)]
pub enum WsError {
/// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")]
UnmaskedFrame,
/// Received a masked frame from server
#[fail(display="Received a masked frame from server")]
MaskedFrame,
/// Encountered invalid opcode
#[fail(display="Invalid opcode: {}", _0)]
InvalidOpcode(u8),
/// Invalid control frame length
#[fail(display="Invalid control frame length: {}", _0)]
InvalidLength(usize),
/// Bad web socket op code
#[fail(display="Bad web socket op code")]
BadOpCode,
/// A payload reached size limit.
#[fail(display="A payload reached size limit.")]
Overflow,
/// Continuation is not supproted
#[fail(display="Continuation is not supproted.")]
NoContinuation,
/// Bad utf-8 encoding
#[fail(display="Bad utf-8 encoding.")]
BadEncoding,
/// Payload error
#[fail(display="Payload error: {}", _0)]
Payload(#[cause] PayloadError),
}
impl ResponseError for WsError {}
impl From<PayloadError> for WsError {
fn from(err: PayloadError) -> WsError {
WsError::Payload(err)
}
}
/// Websocket handshake errors
#[derive(Fail, PartialEq, Debug)]
pub enum WsHandshakeError {
/// Only get method is allowed
#[fail(display="Method not allowed")]
GetMethodRequired,
/// Upgrade header if not set to websocket
#[fail(display="Websocket upgrade is expected")]
NoWebsocketUpgrade,
/// Connection header is not set to upgrade
#[fail(display="Connection upgrade is expected")]
NoConnectionUpgrade,
/// Websocket version header is not set
#[fail(display="Websocket version header is required")]
NoVersionHeader,
/// Unsupported websocket version
#[fail(display="Unsupported version")]
UnsupportedVersion,
/// Websocket key is not set or wrong
#[fail(display="Unknown websocket key")]
BadWebsocketKey,
}
impl ResponseError for WsHandshakeError {
fn error_response(&self) -> HttpResponse {
match *self {
WsHandshakeError::GetMethodRequired => {
HTTPMethodNotAllowed
.build()
.header(header::ALLOW, "GET")
.finish()
.unwrap()
}
WsHandshakeError::NoWebsocketUpgrade =>
HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"),
WsHandshakeError::NoConnectionUpgrade =>
HTTPBadRequest.with_reason("No CONNECTION upgrade"),
WsHandshakeError::NoVersionHeader =>
HTTPBadRequest.with_reason("Websocket version header is required"),
WsHandshakeError::UnsupportedVersion =>
HTTPBadRequest.with_reason("Unsupported version"),
WsHandshakeError::BadWebsocketKey =>
HTTPBadRequest.with_reason("Handshake error"),
}
}
}
/// `WebSocket` Message /// `WebSocket` Message
#[derive(Debug, PartialEq, Message)] #[derive(Debug, PartialEq, Message)]
pub enum Message { pub enum Message {
@ -82,19 +171,18 @@ pub enum Message {
Ping(String), Ping(String),
Pong(String), Pong(String),
Close(CloseCode), Close(CloseCode),
Error
} }
/// Do websocket handshake and start actor /// Do websocket handshake and start actor
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error> pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=WebsocketContext<A, S>> + Handler<Message>, where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, WsError>,
S: 'static S: 'static
{ {
let mut resp = handshake(&req)?; let mut resp = handshake(&req)?;
let stream = WsStream::new(req.clone()); let stream = WsStream::new(req.clone());
let mut ctx = WebsocketContext::new(req, actor); let mut ctx = WebsocketContext::new(req, actor);
ctx.add_message_stream(stream); ctx.add_stream(stream);
Ok(resp.body(ctx)?) Ok(resp.body(ctx)?)
} }
@ -168,33 +256,52 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, WsHands
pub struct WsStream<S> { pub struct WsStream<S> {
rx: PayloadHelper<S>, rx: PayloadHelper<S>,
closed: bool, closed: bool,
max_size: usize,
} }
impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
/// Create new websocket frames stream
pub fn new(stream: S) -> WsStream<S> { pub fn new(stream: S) -> WsStream<S> {
WsStream { rx: PayloadHelper::new(stream), WsStream { rx: PayloadHelper::new(stream),
closed: false } closed: false,
max_size: 65_536,
}
}
/// Set max frame size
///
/// By default max size is set to 64kb
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
} }
} }
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
type Item = Message; type Item = Message;
type Error = (); type Error = WsError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed { if self.closed {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None))
} }
match Frame::parse(&mut self.rx, true) { match Frame::parse(&mut self.rx, true, self.max_size) {
Ok(Async::Ready(Some(frame))) => { Ok(Async::Ready(Some(frame))) => {
// trace!("WsFrame {}", frame); let (finished, opcode, payload) = frame.unpack();
let (_finished, opcode, payload) = frame.unpack();
// continuation is not supported
if !finished {
self.closed = true;
return Err(WsError::NoContinuation)
}
match opcode { match opcode {
OpCode::Continue => unimplemented!(), OpCode::Continue => unimplemented!(),
OpCode::Bad => OpCode::Bad => {
Ok(Async::Ready(Some(Message::Error))), self.closed = true;
Err(WsError::BadOpCode)
}
OpCode::Close => { OpCode::Close => {
self.closed = true; self.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
@ -215,17 +322,19 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
match String::from_utf8(tmp) { match String::from_utf8(tmp) {
Ok(s) => Ok(s) =>
Ok(Async::Ready(Some(Message::Text(s)))), Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => Err(_) => {
Ok(Async::Ready(Some(Message::Error))), self.closed = true;
Err(WsError::BadEncoding)
}
} }
} }
} }
} }
Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => { Err(e) => {
self.closed = true; self.closed = true;
Ok(Async::Ready(Some(Message::Error))) Err(e)
} }
} }
} }
@ -306,4 +415,20 @@ mod tests {
assert_eq!(StatusCode::SWITCHING_PROTOCOLS, assert_eq!(StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().unwrap().status()); handshake(&req).unwrap().finish().unwrap().status());
} }
#[test]
fn test_wserror_http_response() {
let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
} }

View File

@ -16,8 +16,7 @@ impl Actor for Ws {
type Context = ws::WebsocketContext<Self>; type Context = ws::WebsocketContext<Self>;
} }
impl Handler<ws::Message> for Ws { impl StreamHandler<ws::Message, ws::WsError> for Ws {
type Result = ();
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {