From 644f1a951893bd0b1cc10d29cc97d542dc453878 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 26 Feb 2018 13:58:23 -0800 Subject: [PATCH] refactor ws frame parser --- Cargo.toml | 1 - src/client/parser.rs | 6 +- src/error.rs | 30 +++++- src/handler.rs | 1 - src/lib.rs | 6 +- src/multipart.rs | 4 +- src/payload.rs | 35 ++++++- src/ws/client.rs | 242 ++++++++++++++++++------------------------- src/ws/frame.rs | 147 +++++++++++++++----------- src/ws/mod.rs | 139 +++++++++---------------- tests/test_ws.rs | 3 +- 11 files changed, 304 insertions(+), 310 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d8a4b93a9..b7999b743 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,7 +77,6 @@ tokio-tls = { version="0.1", optional = true } openssl = { version="0.10", optional = true } tokio-openssl = { version="0.2", optional = true } -backtrace="*" [dependencies.actix] version = "0.5" diff --git a/src/client/parser.rs b/src/client/parser.rs index 03ba23f99..6a0ee1080 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -151,7 +151,9 @@ impl HttpResponseParser { } } - let decoder = if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { + let decoder = if status == StatusCode::SWITCHING_PROTOCOLS { + Some(Decoder::eof()) + } else if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { // Content-Length if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { @@ -167,8 +169,6 @@ impl HttpResponseParser { } else if chunked(&hdrs)? { // Chunked encoding Some(Decoder::chunked()) - } else if hdrs.contains_key(header::UPGRADE) { - Some(Decoder::eof()) } else { None }; diff --git a/src/error.rs b/src/error.rs index 458850a0f..f94362031 100644 --- a/src/error.rs +++ b/src/error.rs @@ -236,8 +236,6 @@ pub enum PayloadError { impl From for PayloadError { fn from(err: IoError) -> PayloadError { - use backtrace; - println!("IO ERROR {:?}", backtrace::Backtrace::new()); PayloadError::Io(err) } } @@ -391,6 +389,34 @@ impl ResponseError for WsHandshakeError { } } +/// Websocket errors +#[derive(Fail, Debug)] +pub enum WsError { + /// Received an unmasked frame from client + #[fail(display="Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[fail(display="Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[fail(display="Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[fail(display="Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Payload error + #[fail(display="Payload error: {}", _0)] + Payload(#[cause] PayloadError), +} + +impl ResponseError for WsError {} + +impl From for WsError { + fn from(err: PayloadError) -> WsError { + WsError::Payload(err) + } +} + /// A set of errors that can occur during parsing urlencoded payloads #[derive(Fail, Debug)] pub enum UrlencodedError { diff --git a/src/handler.rs b/src/handler.rs index 857c95398..253024e61 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -395,7 +395,6 @@ impl Handler for NormalizePath { } } } else if p.ends_with('/') { - println!("=== {:?}", p); // try to remove trailing slash let p = p.as_ref().trim_right_matches('/'); if router.has_route(p) { diff --git a/src/lib.rs b/src/lib.rs index 23c3f4982..7ec675657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ //! * Supported *HTTP/1.x* and *HTTP/2.0* protocols //! * Streaming and pipelining //! * Keep-alive and slow requests handling -//! * WebSockets server/client +//! * *WebSockets* server/client //! * Transparent content compression/decompression (br, gzip, deflate) //! * Configurable request routing //! * Multipart streams @@ -44,7 +44,7 @@ specialization, // for impl ErrorResponse for std::error::Error ))] #![cfg_attr(feature = "cargo-clippy", allow( - decimal_literal_representation,))] + decimal_literal_representation,suspicious_arithmetic_impl,))] #[macro_use] extern crate log; @@ -97,8 +97,6 @@ extern crate openssl; #[cfg(feature="openssl")] extern crate tokio_openssl; -extern crate backtrace; - mod application; mod body; mod context; diff --git a/src/multipart.rs b/src/multipart.rs index d3790a9a4..457fe4beb 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -492,10 +492,10 @@ impl InnerField where S: Stream { if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--" && &chunk[4..] == boundary.as_bytes() { - payload.unread_data(chunk); + payload.unread_data(chunk.freeze()); Ok(Async::Ready(None)) } else { - Ok(Async::Ready(Some(chunk))) + Ok(Async::Ready(Some(chunk.freeze()))) } } } diff --git a/src/payload.rs b/src/payload.rs index 4cc0eaf68..e4ba819d1 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -297,9 +297,9 @@ impl Inner { buf.extend_from_slice(&chunk.split_to(rem)); if !chunk.is_empty() { self.items.push_front(chunk); - return Ok(Async::Ready(buf.freeze())) } } + return Ok(Async::Ready(buf.freeze())) } if let Some(err) = self.err.take() { @@ -423,7 +423,7 @@ impl PayloadHelper where S: Stream { PayloadHelper { len: 0, items: VecDeque::new(), - stream: stream, + stream, } } @@ -445,6 +445,10 @@ impl PayloadHelper where S: Stream { self.len } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn readany(&mut self) -> Poll, PayloadError> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); @@ -458,7 +462,7 @@ impl PayloadHelper where S: Stream { } } - pub fn readexactly(&mut self, size: usize) -> Poll, PayloadError> { + pub fn readexactly(&mut self, size: usize) -> Poll, PayloadError> { if size <= self.len { let mut buf = BytesMut::with_capacity(size); while buf.len() < size { @@ -468,13 +472,34 @@ impl PayloadHelper where S: Stream { buf.extend_from_slice(&chunk.split_to(rem)); if !chunk.is_empty() { self.items.push_front(chunk); - return Ok(Async::Ready(Some(buf.freeze()))) + } + } + return Ok(Async::Ready(Some(buf))) + } + + match self.poll_stream()? { + Async::Ready(true) => self.readexactly(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + + pub fn copy(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + let mut buf = BytesMut::with_capacity(size); + for chunk in &self.items { + if buf.len() < size { + let rem = cmp::min(size - buf.len(), chunk.len()); + buf.extend_from_slice(&chunk[..rem]); + } + if buf.len() == size { + return Ok(Async::Ready(Some(buf))) } } } match self.poll_stream()? { - Async::Ready(true) => self.readexactly(size), + Async::Ready(true) => self.copy(size), Async::Ready(false) => Ok(Async::Ready(None)), Async::NotReady => Ok(Async::NotReady), } diff --git a/src/ws/client.rs b/src/ws/client.rs index 2023fdd21..369b56315 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -8,24 +8,28 @@ use std::cell::UnsafeCell; use base64; use rand; use cookie::Cookie; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use http::{HttpTryFrom, StatusCode, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use sha1::Sha1; use futures::{Async, Future, Poll, Stream}; use futures::future::{Either, err as FutErr}; +use futures::unsync::mpsc::{unbounded, UnboundedSender}; use tokio_core::net::TcpStream; +use byteorder::{ByteOrder, NetworkEndian}; use actix::prelude::*; -use body::Binary; -use error::UrlParseError; +use body::{Body, Binary}; +use error::{WsError, UrlParseError}; +use payload::PayloadHelper; use server::shared::SharedBytes; use server::{utils, IoStream}; -use client::{ClientRequest, ClientRequestBuilder, +use client::{ClientRequest, ClientRequestBuilder, ClientResponse, HttpResponseParser, HttpResponseParserError, HttpClientWriter}; -use client::{Connect, Connection, ClientConnector, ClientConnectorError}; +use client::{Connect, Connection, ClientConnector, ClientConnectorError, + SendRequest, SendRequestError}; use super::Message; use super::frame::Frame; @@ -52,7 +56,9 @@ pub enum WsClientError { #[fail(display="Response parsing error")] ResponseParseError(HttpResponseParserError), #[fail(display="{}", _0)] - Connector(ClientConnectorError), + SendRequest(SendRequestError), + #[fail(display="{}", _0)] + Protocol(#[cause] WsError), #[fail(display="{}", _0)] Io(io::Error), #[fail(display="Disconnected")] @@ -71,9 +77,15 @@ impl From for WsClientError { } } -impl From for WsClientError { - fn from(err: ClientConnectorError) -> WsClientError { - WsClientError::Connector(err) +impl From for WsClientError { + fn from(err: SendRequestError) -> WsClientError { + WsClientError::SendRequest(err) + } +} + +impl From for WsClientError { + fn from(err: WsError) -> WsClientError { + WsClientError::Protocol(err) } } @@ -206,21 +218,17 @@ impl WsClient { } struct WsInner { - conn: Connection, - writer: HttpClientWriter, - parser: HttpResponseParser, - parser_buf: BytesMut, + tx: UnboundedSender, + rx: PayloadHelper, closed: bool, - error_sent: bool, } pub struct WsHandshake { inner: Option, - request: Option, - sent: bool, + request: Option, + tx: Option>, key: String, error: Option, - stream: Option, Error=WsClientError>>>, } impl WsHandshake { @@ -235,31 +243,29 @@ impl WsHandshake { let key = base64::encode(&sec_key); if let Some(mut request) = request { - let stream = Box::new( - conn.send(Connect(request.uri().clone())) - .map(|res| res.map_err(|e| e.into())) - .map_err(|_| WsClientError::Disconnected)); - request.headers_mut().insert( HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), HeaderValue::try_from(key.as_str()).unwrap()); + let (tx, rx) = unbounded(); + request.set_body(Body::Streaming( + Box::new(rx.map_err(|_| io::Error::new( + io::ErrorKind::Other, "disconnected").into())))); + WsHandshake { key: key, inner: None, - request: Some(request), - sent: false, + request: Some(request.with_connector(conn.clone())), + tx: Some(tx), error: err, - stream: Some(stream), } } else { WsHandshake { key: key, inner: None, request: None, - sent: false, + tx: None, error: err, - stream: None, } } } @@ -274,94 +280,67 @@ impl Future for WsHandshake { return Err(err) } - if self.stream.is_some() { - match self.stream.as_mut().unwrap().poll()? { - Async::Ready(result) => match result { - Ok(conn) => { - let inner = WsInner { - conn: conn, - writer: HttpClientWriter::new(SharedBytes::default()), - parser: HttpResponseParser::default(), - parser_buf: BytesMut::new(), - closed: false, - error_sent: false, - }; - self.stream.take(); - self.inner = Some(inner); - } - Err(err) => return Err(err), - }, - Async::NotReady => return Ok(Async::NotReady) + let resp = match self.request.as_mut().unwrap().poll()? { + Async::Ready(response) => { + self.request.take(); + response + }, + Async::NotReady => return Ok(Async::NotReady) + }; + + // verify response + if resp.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus) + } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false } + } else { + false + }; + if !has_hdr { + return Err(WsClientError::InvalidUpgradeHeader) + } + // Check for "CONNECTION" header + let has_hdr = if let Some(conn) = resp.headers().get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + s.to_lowercase().contains("upgrade") + } else { false } + } else { false }; + if !has_hdr { + return Err(WsClientError::InvalidConnectionHeader) } - let mut inner = self.inner.take().unwrap(); - - if !self.sent { - self.sent = true; - inner.writer.start(self.request.as_mut().unwrap())?; - } - if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { - return Err(err.into()) + let match_key = if let Some(key) = resp.headers().get( + HeaderName::try_from("SEC-WEBSOCKET-ACCEPT").unwrap()) + { + // field is constructed by concatenating /key/ + // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(self.key.as_ref()); + sha1.update(WS_GUID); + key.as_bytes() == base64::encode(&sha1.digest().bytes()).as_bytes() + } else { + false + }; + if !match_key { + return Err(WsClientError::InvalidChallengeResponse) } - match inner.parser.parse(&mut inner.conn, &mut inner.parser_buf) { - Ok(Async::Ready(resp)) => { - // verify response - if resp.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(WsClientError::InvalidResponseStatus) - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - return Err(WsClientError::InvalidUpgradeHeader) - } - // Check for "CONNECTION" header - let has_hdr = if let Some(conn) = resp.headers().get(header::CONNECTION) { - if let Ok(s) = conn.to_str() { - s.to_lowercase().contains("upgrade") - } else { false } - } else { false }; - if !has_hdr { - return Err(WsClientError::InvalidConnectionHeader) - } + let inner = WsInner { + tx: self.tx.take().unwrap(), + rx: PayloadHelper::new(resp), + closed: false, + }; - let match_key = if let Some(key) = resp.headers().get( - HeaderName::try_from("SEC-WEBSOCKET-ACCEPT").unwrap()) - { - // field is constructed by concatenating /key/ - // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut sha1 = Sha1::new(); - sha1.update(self.key.as_ref()); - sha1.update(WS_GUID); - key.as_bytes() == base64::encode(&sha1.digest().bytes()).as_bytes() - } else { - false - }; - if !match_key { - return Err(WsClientError::InvalidChallengeResponse) - } - - let inner = Rc::new(UnsafeCell::new(inner)); - Ok(Async::Ready( - (WsClientReader{inner: Rc::clone(&inner)}, - WsClientWriter{inner: inner}))) - }, - Ok(Async::NotReady) => { - self.inner = Some(inner); - Ok(Async::NotReady) - }, - Err(err) => Err(err.into()) - } + let inner = Rc::new(UnsafeCell::new(inner)); + Ok(Async::Ready( + (WsClientReader{inner: Rc::clone(&inner)}, WsClientWriter{inner: inner}))) } } @@ -389,24 +368,13 @@ impl Stream for WsClientReader { fn poll(&mut self) -> Poll, Self::Error> { let inner = self.as_mut(); - let mut done = false; - - match utils::read_from_io(&mut inner.conn, &mut inner.parser_buf) { - Ok(Async::Ready(0)) => { - done = true; - inner.closed = true; - }, - Ok(Async::Ready(_)) | Ok(Async::NotReady) => (), - Err(err) => - return Err(err.into()) + if inner.closed { + return Ok(Async::Ready(None)) } - // write - let _ = inner.writer.poll_completed(&mut inner.conn, false); - // read - match Frame::parse(&mut inner.parser_buf, false) { - Ok(Some(frame)) => { + match Frame::parse(&mut inner.rx, false) { + Ok(Async::Ready(Some(frame))) => { // trace!("WsFrame {}", frame); let (_finished, opcode, payload) = frame.unpack(); @@ -416,8 +384,9 @@ impl Stream for WsClientReader { Ok(Async::Ready(Some(Message::Error))), OpCode::Close => { inner.closed = true; - inner.error_sent = true; - Ok(Async::Ready(Some(Message::Closed))) + let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; + Ok(Async::Ready( + Some(Message::Close(CloseCode::from(code))))) }, OpCode::Ping => Ok(Async::Ready(Some( @@ -440,23 +409,10 @@ impl Stream for WsClientReader { } } } - Ok(None) => { - if done { - Ok(Async::Ready(None)) - } else if inner.closed { - if !inner.error_sent { - inner.error_sent = true; - Ok(Async::Ready(Some(Message::Closed))) - } else { - Ok(Async::Ready(None)) - } - } else { - Ok(Async::NotReady) - } - }, + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { inner.closed = true; - inner.error_sent = true; Err(err.into()) } } @@ -478,9 +434,9 @@ impl WsClientWriter { /// Write payload #[inline] - fn write(&mut self, data: Binary) { + fn write(&mut self, mut data: Binary) { if !self.as_mut().closed { - let _ = self.as_mut().writer.write(data); + let _ = self.as_mut().tx.unbounded_send(data.take()); } else { warn!("Trying to write to disconnected response"); } diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 8771435fa..22067cb60 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -1,11 +1,13 @@ use std::{fmt, mem}; -use std::io::{Error, ErrorKind}; use std::iter::FromIterator; -use bytes::{BytesMut, BufMut}; +use bytes::{Bytes, BytesMut, BufMut}; use byteorder::{ByteOrder, BigEndian, NetworkEndian}; +use futures::{Async, Poll, Stream}; use rand; use body::Binary; +use error::{WsError, PayloadError}; +use payload::PayloadHelper; use ws::proto::{OpCode, CloseCode}; use ws::mask::apply_mask; @@ -48,14 +50,15 @@ impl Frame { } /// Parse the input stream into a frame. - pub fn parse(buf: &mut BytesMut, server: bool) -> Result, Error> { + pub fn parse(pl: &mut PayloadHelper, server: bool) -> Poll, WsError> + where S: Stream + { let mut idx = 2; - let mut size = buf.len(); - - if size < 2 { - return Ok(None) - } - size -= 2; + let buf = match pl.copy(2)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let first = buf[0]; let second = buf[1]; let finished = first & 0x80 != 0; @@ -63,11 +66,9 @@ impl Frame { // check masking let masked = second & 0x80 != 0; if !masked && server { - return Err(Error::new( - ErrorKind::Other, "Received an unmasked frame from client")) + return Err(WsError::UnmaskedFrame) } else if masked && !server { - return Err(Error::new( - ErrorKind::Other, "Received a masked frame from server")) + return Err(WsError::MaskedFrame) } let rsv1 = first & 0x40 != 0; @@ -77,19 +78,21 @@ impl Frame { let len = second & 0x7F; let length = if len == 126 { - if size < 2 { - return Ok(None) - } + let buf = match pl.copy(4)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let len = NetworkEndian::read_uint(&buf[idx..], 2) as usize; - size -= 2; idx += 2; len } else if len == 127 { - if size < 8 { - return Ok(None) - } + let buf = match pl.copy(10)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let len = NetworkEndian::read_uint(&buf[idx..], 8) as usize; - size -= 8; idx += 8; len } else { @@ -97,50 +100,42 @@ impl Frame { }; let mask = if server { - if size < 4 { - return Ok(None) - } else { - let mut mask_bytes = [0u8; 4]; - size -= 4; - mask_bytes.copy_from_slice(&buf[idx..idx+4]); - idx += 4; - Some(mask_bytes) - } + let buf = match pl.copy(idx + 4)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; + + let mut mask_bytes = [0u8; 4]; + mask_bytes.copy_from_slice(&buf[idx..idx+4]); + idx += 4; + Some(mask_bytes) } else { None }; - if size < length { - return Ok(None) - } + let mut data = match pl.readexactly(idx + length)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; // get body - buf.split_to(idx); - let mut data = if length > 0 { - buf.split_to(length) - } else { - BytesMut::new() - }; + data.split_to(idx); // Disallow bad opcode if let OpCode::Bad = opcode { - return Err( - Error::new( - ErrorKind::Other, - format!("Encountered invalid opcode: {}", first & 0x0F))) + return Err(WsError::InvalidOpcode(first & 0x0F)) } // control frames must have length <= 125 match opcode { OpCode::Ping | OpCode::Pong if length > 125 => { - return Err( - Error::new( - ErrorKind::Other, - format!("Rejected WebSocket handshake.Received control frame with length: {}.", length))) + return Err(WsError::InvalidLength(length)) } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some(Frame::default())) + return Ok(Async::Ready(Some(Frame::default()))) } _ => () } @@ -150,14 +145,14 @@ impl Frame { apply_mask(&mut data, mask); } - Ok(Some(Frame { + Ok(Async::Ready(Some(Frame { finished: finished, rsv1: rsv1, rsv2: rsv2, rsv3: rsv3, opcode: opcode, payload: data.into(), - })) + }))) } /// Generate binary representation @@ -258,13 +253,33 @@ impl fmt::Display for Frame { #[cfg(test)] mod tests { use super::*; + use futures::stream::once; + + fn is_none(frm: Poll, WsError>) -> bool { + match frm { + Ok(Async::Ready(None)) => true, + _ => false, + } + } + + fn extract(frm: Poll, WsError>) -> Frame { + match frm { + Ok(Async::Ready(Some(frame))) => frame, + _ => panic!("error"), + } + } #[test] fn test_parse() { + let mut buf = PayloadHelper::new( + once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze()))); + assert!(is_none(Frame::parse(&mut buf, false))); + let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(b"1"); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + let frame = extract(Frame::parse(&mut buf, false)); println!("FRAME: {}", frame); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); @@ -273,8 +288,10 @@ mod tests { #[test] fn test_parse_length0() { - let mut buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + let frame = extract(Frame::parse(&mut buf, false)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert!(frame.payload.is_empty()); @@ -282,12 +299,16 @@ mod tests { #[test] fn test_parse_length2() { + let buf = BytesMut::from(&[0b00000001u8, 126u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + assert!(is_none(Frame::parse(&mut buf, false))); + let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(&[0u8, 4u8][..]); buf.extend(b"1234"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -295,12 +316,16 @@ mod tests { #[test] fn test_parse_length4() { + let buf = BytesMut::from(&[0b00000001u8, 127u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + assert!(is_none(Frame::parse(&mut buf, false))); + let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(b"1234"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -311,10 +336,11 @@ mod tests { let mut buf = BytesMut::from(&[0b00000001u8, 0b10000001u8][..]); buf.extend(b"0001"); buf.extend(b"1"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); assert!(Frame::parse(&mut buf, false).is_err()); - let frame = Frame::parse(&mut buf, true).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, true)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); @@ -324,10 +350,11 @@ mod tests { fn test_parse_frame_no_mask() { let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); buf.extend(&[1u8]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); assert!(Frame::parse(&mut buf, true).is_err()); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 93d0b61aa..465400ac9 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -43,15 +43,16 @@ //! # .finish(); //! # } //! ``` -use bytes::BytesMut; +use bytes::Bytes; use http::{Method, StatusCode, header}; use futures::{Async, Poll, Stream}; +use byteorder::{ByteOrder, NetworkEndian}; use actix::{Actor, AsyncContext, Handler}; use body::Binary; -use payload::Payload; -use error::{Error, WsHandshakeError}; +use payload::PayloadHelper; +use error::{Error, WsHandshakeError, PayloadError}; use httprequest::HttpRequest; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; @@ -80,8 +81,7 @@ pub enum Message { Binary(Binary), Ping(String), Pong(String), - Close, - Closed, + Close(CloseCode), Error } @@ -165,104 +165,67 @@ pub fn handshake(req: &HttpRequest) -> Result { + rx: PayloadHelper, closed: bool, - error_sent: bool, } -impl WsStream { - pub fn new(payload: Payload) -> WsStream { - WsStream { rx: payload, - buf: BytesMut::new(), - closed: false, - error_sent: false } +impl WsStream where S: Stream { + pub fn new(stream: S) -> WsStream { + WsStream { rx: PayloadHelper::new(stream), + closed: false } } } -impl Stream for WsStream { +impl Stream for WsStream where S: Stream { type Item = Message; type Error = (); fn poll(&mut self) -> Poll, Self::Error> { - let mut done = false; + if self.closed { + return Ok(Async::Ready(None)) + } - if !self.closed { - loop { - match self.rx.poll() { - Ok(Async::Ready(Some(chunk))) => { - self.buf.extend_from_slice(&chunk) - } - Ok(Async::Ready(None)) => { - done = true; + match Frame::parse(&mut self.rx, true) { + Ok(Async::Ready(Some(frame))) => { + // trace!("WsFrame {}", frame); + let (_finished, opcode, payload) = frame.unpack(); + + match opcode { + OpCode::Continue => unimplemented!(), + OpCode::Bad => + Ok(Async::Ready(Some(Message::Error))), + OpCode::Close => { self.closed = true; - break; - } - Ok(Async::NotReady) => break, - Err(_) => { - self.closed = true; - break; + let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; + Ok(Async::Ready( + Some(Message::Close(CloseCode::from(code))))) + }, + OpCode::Ping => + Ok(Async::Ready(Some( + Message::Ping( + String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Pong => + Ok(Async::Ready(Some( + Message::Pong(String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Binary => + Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Text => { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => + Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => + Ok(Async::Ready(Some(Message::Error))), + } } } } - } - - loop { - match Frame::parse(&mut self.buf, true) { - Ok(Some(frame)) => { - // trace!("WsFrame {}", frame); - let (_finished, opcode, payload) = frame.unpack(); - - match opcode { - OpCode::Continue => continue, - OpCode::Bad => - return Ok(Async::Ready(Some(Message::Error))), - OpCode::Close => { - self.closed = true; - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Closed))) - }, - OpCode::Ping => - return Ok(Async::Ready(Some( - Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Pong => - return Ok(Async::Ready(Some( - Message::Pong( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Binary => - return Ok(Async::Ready(Some(Message::Binary(payload)))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => - return Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => - return Ok(Async::Ready(Some(Message::Error))), - } - } - } - } - Ok(None) => { - if done { - return Ok(Async::Ready(None)) - } else if self.closed { - if !self.error_sent { - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Closed))) - } else { - return Ok(Async::Ready(None)) - } - } else { - return Ok(Async::NotReady) - } - }, - Err(_) => { - self.closed = true; - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Error))); - } + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => { + self.closed = true; + Ok(Async::Ready(Some(Message::Error))) } } } diff --git a/tests/test_ws.rs b/tests/test_ws.rs index ac7119914..8b7d0111f 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -24,6 +24,7 @@ impl Handler for Ws { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), + ws::Message::Close(reason) => ctx.close(reason, ""), _ => (), } } @@ -49,5 +50,5 @@ fn test_simple() { writer.close(ws::CloseCode::Normal, ""); let (item, _) = srv.execute(reader.into_future()).unwrap(); - assert!(item.is_none()); + assert_eq!(item, Some(ws::Message::Close(ws::CloseCode::Normal))); }