From 5c0a2066cc901c136062121345684332a2e98ac3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 5 Oct 2018 12:47:22 -0700 Subject: [PATCH] refactor ws to a websocket codec --- src/body.rs | 2 - src/ws/codec.rs | 119 +++++++++++++++ src/ws/frame.rs | 382 ++++++++++++++---------------------------------- src/ws/mod.rs | 133 ++--------------- 4 files changed, 238 insertions(+), 398 deletions(-) create mode 100644 src/ws/codec.rs diff --git a/src/body.rs b/src/body.rs index db06bef22..c10b067a2 100644 --- a/src/body.rs +++ b/src/body.rs @@ -17,8 +17,6 @@ pub enum Body { /// Unspecified streaming response. Developer is responsible for setting /// right `Content-Length` or `Transfer-Encoding` headers. Streaming(BodyStream), - // /// Special body type for actor response. - // Actor(Box), } /// Represents various types of binary body. diff --git a/src/ws/codec.rs b/src/ws/codec.rs new file mode 100644 index 000000000..6e2b12090 --- /dev/null +++ b/src/ws/codec.rs @@ -0,0 +1,119 @@ +use bytes::BytesMut; +use tokio_codec::{Decoder, Encoder}; + +use super::frame::Frame; +use super::proto::{CloseReason, OpCode}; +use super::ProtocolError; +use body::Binary; + +/// `WebSocket` Message +#[derive(Debug, PartialEq)] +pub enum Message { + /// Text message + Text(String), + /// Binary message + Binary(Binary), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), +} + +/// WebSockets protocol codec +pub struct Codec { + max_size: usize, + server: bool, +} + +impl Codec { + /// Create new websocket frames decoder + pub fn new() -> Codec { + Codec { + max_size: 65_536, + server: true, + } + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + pub fn client_mode(mut self) -> Self { + self.server = false; + self + } +} + +impl Encoder for Codec { + type Item = Message; + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + Message::Text(txt) => { + Frame::write_message(dst, txt, OpCode::Text, true, !self.server) + } + Message::Binary(bin) => { + Frame::write_message(dst, bin, OpCode::Binary, true, !self.server) + } + Message::Ping(txt) => { + Frame::write_message(dst, txt, OpCode::Ping, true, !self.server) + } + Message::Pong(txt) => { + Frame::write_message(dst, txt, OpCode::Pong, true, !self.server) + } + Message::Close(reason) => Frame::write_close(dst, reason, !self.server), + } + Ok(()) + } +} + +impl Decoder for Codec { + type Item = Message; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Frame::parse(src, self.server, self.max_size) { + Ok(Some((finished, opcode, payload))) => { + // continuation is not supported + if !finished { + return Err(ProtocolError::NoContinuation); + } + + match opcode { + OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Bad => Err(ProtocolError::BadOpCode), + OpCode::Close => { + let close_reason = Frame::parse_close_payload(&payload); + Ok(Some(Message::Close(close_reason))) + } + OpCode::Ping => Ok(Some(Message::Ping( + String::from_utf8_lossy(payload.as_ref()).into(), + ))), + OpCode::Pong => Ok(Some(Message::Pong( + String::from_utf8_lossy(payload.as_ref()).into(), + ))), + OpCode::Binary => Ok(Some(Message::Binary(payload))), + OpCode::Text => { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => Ok(Some(Message::Text(s))), + Err(_) => Err(ProtocolError::BadEncoding), + } + } + } + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } +} diff --git a/src/ws/frame.rs b/src/ws/frame.rs index d5fa98272..38bebc283 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -1,144 +1,29 @@ use byteorder::{ByteOrder, LittleEndian, NetworkEndian}; -use bytes::{BufMut, Bytes, BytesMut}; -use futures::{Async, Poll, Stream}; +use bytes::{BufMut, BytesMut}; use rand; -use std::fmt; use body::Binary; -use error::PayloadError; -use payload::PayloadBuffer; - use ws::mask::apply_mask; use ws::proto::{CloseCode, CloseReason, OpCode}; use ws::ProtocolError; /// A struct representing a `WebSocket` frame. #[derive(Debug)] -pub struct Frame { - finished: bool, - opcode: OpCode, - payload: Binary, -} +pub struct Frame; impl Frame { - /// Destruct frame - pub fn unpack(self) -> (bool, OpCode, Binary) { - (self.finished, self.opcode, self.payload) - } - - /// Create a new Close control frame. - #[inline] - pub fn close(reason: Option, genmask: bool) -> FramedMessage { - let payload = match reason { - None => Vec::new(), - Some(reason) => { - let mut code_bytes = [0; 2]; - NetworkEndian::write_u16(&mut code_bytes, reason.code.into()); - - let mut payload = Vec::from(&code_bytes[..]); - if let Some(description) = reason.description { - payload.extend(description.as_bytes()); - } - payload - } - }; - - Frame::message(payload, OpCode::Close, true, genmask) - } - - #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] - fn read_copy_md( - pl: &mut PayloadBuffer, server: bool, max_size: usize, - ) -> Poll)>, ProtocolError> - where - S: Stream, - { - let mut idx = 2; - let buf = match pl.copy(2)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let first = buf[0]; - let second = buf[1]; - let finished = first & 0x80 != 0; - - // check masking - let masked = second & 0x80 != 0; - if !masked && server { - return Err(ProtocolError::UnmaskedFrame); - } else if masked && !server { - return Err(ProtocolError::MaskedFrame); - } - - // Op code - let opcode = OpCode::from(first & 0x0F); - - if let OpCode::Bad = opcode { - return Err(ProtocolError::InvalidOpcode(first & 0x0F)); - } - - let len = second & 0x7F; - let length = if len == 126 { - let buf = match pl.copy(4)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let len = NetworkEndian::read_uint(&buf[idx..], 2) as usize; - idx += 2; - len - } else if len == 127 { - let buf = match pl.copy(10)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let len = NetworkEndian::read_uint(&buf[idx..], 8); - if len > max_size as u64 { - return Err(ProtocolError::Overflow); - } - idx += 8; - len as usize - } else { - len as usize - }; - - // check for max allowed size - if length > max_size { - return Err(ProtocolError::Overflow); - } - - let mask = if server { - let buf = match pl.copy(idx + 4)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - - let mask: &[u8] = &buf[idx..idx + 4]; - let mask_u32 = LittleEndian::read_u32(mask); - idx += 4; - Some(mask_u32) - } else { - None - }; - - Ok(Async::Ready(Some((idx, finished, opcode, length, mask)))) - } - - fn read_chunk_md( - chunk: &[u8], server: bool, max_size: usize, - ) -> Poll<(usize, bool, OpCode, usize, Option), ProtocolError> { - let chunk_len = chunk.len(); + fn parse_metadata( + src: &[u8], server: bool, max_size: usize, + ) -> Result)>, ProtocolError> { + let chunk_len = src.len(); let mut idx = 2; if chunk_len < 2 { - return Ok(Async::NotReady); + return Ok(None); } - let first = chunk[0]; - let second = chunk[1]; + let first = src[0]; + let second = src[1]; let finished = first & 0x80 != 0; // check masking @@ -159,16 +44,16 @@ impl Frame { let len = second & 0x7F; let length = if len == 126 { if chunk_len < 4 { - return Ok(Async::NotReady); + return Ok(None); } - let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize; + let len = NetworkEndian::read_uint(&src[idx..], 2) as usize; idx += 2; len } else if len == 127 { if chunk_len < 10 { - return Ok(Async::NotReady); + return Ok(None); } - let len = NetworkEndian::read_uint(&chunk[idx..], 8); + let len = NetworkEndian::read_uint(&src[idx..], 8); if len > max_size as u64 { return Err(ProtocolError::Overflow); } @@ -185,10 +70,10 @@ impl Frame { let mask = if server { if chunk_len < idx + 4 { - return Ok(Async::NotReady); + return Ok(None); } - let mask: &[u8] = &chunk[idx..idx + 4]; + let mask: &[u8] = &src[idx..idx + 4]; let mask_u32 = LittleEndian::read_u32(mask); idx += 4; Some(mask_u32) @@ -196,56 +81,34 @@ impl Frame { None }; - Ok(Async::Ready((idx, finished, opcode, length, mask))) + Ok(Some((idx, finished, opcode, length, mask))) } /// Parse the input stream into a frame. - pub fn parse( - pl: &mut PayloadBuffer, server: bool, max_size: usize, - ) -> Poll, ProtocolError> - where - S: Stream, - { - // try to parse ws frame md from one chunk - let result = match pl.get_chunk()? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::Ready(Some(chunk)) => Frame::read_chunk_md(chunk, server, max_size)?, - }; + pub fn parse( + src: &mut BytesMut, server: bool, max_size: usize, + ) -> Result, ProtocolError> { + // try to parse ws frame metadata + let (idx, finished, opcode, length, mask) = + match Frame::parse_metadata(src, server, max_size)? { + None => return Ok(None), + Some(res) => res, + }; - let (idx, finished, opcode, length, mask) = match result { - // we may need to join several chunks - Async::NotReady => match Frame::read_copy_md(pl, server, max_size)? { - Async::Ready(Some(item)) => item, - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(None)), - }, - Async::Ready(item) => item, - }; - - match pl.can_read(idx + length)? { - Async::Ready(Some(true)) => (), - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::Ready(Some(false)) | Async::NotReady => return Ok(Async::NotReady), + // not enough data + if src.len() < idx + length { + return Ok(None); } // remove prefix - pl.drop_bytes(idx); + src.split_to(idx); // no need for body if length == 0 { - return Ok(Async::Ready(Some(Frame { - finished, - opcode, - payload: Binary::from(""), - }))); + return Ok(Some((finished, opcode, Binary::from("")))); } - let data = match pl.read_exact(length)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => panic!(), - }; + let mut data = src.split_to(length); // control frames must have length <= 125 match opcode { @@ -254,26 +117,17 @@ impl Frame { } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Async::Ready(Some(Frame::default()))); + return Ok(Some((true, OpCode::Close, Binary::from("")))); } _ => (), } // unmask - let data = if let Some(mask) = mask { - let mut buf = BytesMut::new(); - buf.extend_from_slice(&data); - apply_mask(&mut buf, mask); - buf.freeze() - } else { - data - }; + if let Some(mask) = mask { + apply_mask(&mut data, mask); + } - Ok(Async::Ready(Some(Frame { - finished, - opcode, - payload: data.into(), - }))) + Ok(Some((finished, opcode, data.into()))) } /// Parse the payload of a close frame. @@ -293,120 +147,101 @@ impl Frame { } /// Generate binary representation - pub fn message>( - data: B, code: OpCode, finished: bool, genmask: bool, - ) -> FramedMessage { - let payload = data.into(); - let one: u8 = if finished { - 0x80 | Into::::into(code) + pub fn write_message>( + dst: &mut BytesMut, pl: B, op: OpCode, fin: bool, mask: bool, + ) { + let payload = pl.into(); + let one: u8 = if fin { + 0x80 | Into::::into(op) } else { - code.into() + op.into() }; let payload_len = payload.len(); - let (two, p_len) = if genmask { + let (two, p_len) = if mask { (0x80, payload_len + 4) } else { (0, payload_len) }; - let mut buf = if payload_len < 126 { - let mut buf = BytesMut::with_capacity(p_len + 2); - buf.put_slice(&[one, two | payload_len as u8]); - buf + if payload_len < 126 { + dst.put_slice(&[one, two | payload_len as u8]); } else if payload_len <= 65_535 { - let mut buf = BytesMut::with_capacity(p_len + 4); - buf.put_slice(&[one, two | 126]); - buf.put_u16_be(payload_len as u16); - buf + dst.reserve(p_len + 4); + dst.put_slice(&[one, two | 126]); + dst.put_u16_be(payload_len as u16); } else { - let mut buf = BytesMut::with_capacity(p_len + 10); - buf.put_slice(&[one, two | 127]); - buf.put_u64_be(payload_len as u64); - buf + dst.reserve(p_len + 10); + dst.put_slice(&[one, two | 127]); + dst.put_u64_be(payload_len as u64); }; - let binary = if genmask { + if mask { let mask = rand::random::(); - buf.put_u32_le(mask); - buf.extend_from_slice(payload.as_ref()); - let pos = buf.len() - payload_len; - apply_mask(&mut buf[pos..], mask); - buf.into() + dst.put_u32_le(mask); + dst.extend_from_slice(payload.as_ref()); + let pos = dst.len() - payload_len; + apply_mask(&mut dst[pos..], mask); } else { - buf.put_slice(payload.as_ref()); - buf.into() - }; - - FramedMessage(binary) - } -} - -impl Default for Frame { - fn default() -> Frame { - Frame { - finished: true, - opcode: OpCode::Close, - payload: Binary::from(&b""[..]), + dst.put_slice(payload.as_ref()); } } -} -impl fmt::Display for Frame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - " - - final: {} - opcode: {} - payload length: {} - payload: 0x{} -", - self.finished, - self.opcode, - self.payload.len(), - self.payload - .as_ref() - .iter() - .map(|byte| format!("{:x}", byte)) - .collect::() - ) + /// Create a new Close control frame. + #[inline] + pub fn write_close(dst: &mut BytesMut, reason: Option, mask: bool) { + let payload = match reason { + None => Vec::new(), + Some(reason) => { + let mut code_bytes = [0; 2]; + NetworkEndian::write_u16(&mut code_bytes, reason.code.into()); + + let mut payload = Vec::from(&code_bytes[..]); + if let Some(description) = reason.description { + payload.extend(description.as_bytes()); + } + payload + } + }; + + Frame::write_message(dst, payload, OpCode::Close, true, mask) } } -/// `WebSocket` message with framing. -#[derive(Debug)] -pub struct FramedMessage(pub(crate) Binary); - #[cfg(test)] mod tests { use super::*; - use futures::stream::once; - fn is_none(frm: &Poll, ProtocolError>) -> bool { + struct F { + finished: bool, + opcode: OpCode, + payload: Binary, + } + + fn is_none(frm: &Result, ProtocolError>) -> bool { match *frm { - Ok(Async::Ready(None)) => true, + Ok(None) => true, _ => false, } } - fn extract(frm: Poll, ProtocolError>) -> Frame { + fn extract(frm: Result, ProtocolError>) -> F { match frm { - Ok(Async::Ready(Some(frame))) => frame, + Ok(Some((finished, opcode, payload))) => F { + finished, + opcode, + payload, + }, _ => unreachable!("error"), } } #[test] fn test_parse() { - let mut buf = PayloadBuffer::new(once(Ok(BytesMut::from( - &[0b0000_0001u8, 0b0000_0001u8][..], - ).freeze()))); + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); assert!(is_none(&Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); buf.extend(b"1"); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); @@ -416,9 +251,7 @@ mod tests { #[test] fn test_parse_length0() { - let buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); - + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); @@ -427,14 +260,12 @@ mod tests { #[test] fn test_parse_length2() { - let buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); + let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); assert!(is_none(&Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); buf.extend(&[0u8, 4u8][..]); buf.extend(b"1234"); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); @@ -444,14 +275,12 @@ mod tests { #[test] fn test_parse_length4() { - let buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); + let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); assert!(is_none(&Frame::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(b"1234"); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); @@ -464,7 +293,6 @@ mod tests { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]); buf.extend(b"0001"); buf.extend(b"1"); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); assert!(Frame::parse(&mut buf, false, 1024).is_err()); @@ -478,7 +306,6 @@ mod tests { fn test_parse_frame_no_mask() { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); buf.extend(&[1u8]); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); assert!(Frame::parse(&mut buf, true, 1024).is_err()); @@ -492,7 +319,6 @@ mod tests { fn test_parse_frame_max_size() { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); buf.extend(&[1u8, 1u8]); - let mut buf = PayloadBuffer::new(once(Ok(buf.freeze()))); assert!(Frame::parse(&mut buf, true, 1).is_err()); @@ -504,35 +330,39 @@ mod tests { #[test] fn test_ping_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); + let mut buf = BytesMut::new(); + Frame::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); let mut v = vec![137u8, 4u8]; v.extend(b"data"); - assert_eq!(frame.0, v.into()); + assert_eq!(&buf[..], &v[..]); } #[test] fn test_pong_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false); + let mut buf = BytesMut::new(); + Frame::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); let mut v = vec![138u8, 4u8]; v.extend(b"data"); - assert_eq!(frame.0, v.into()); + assert_eq!(&buf[..], &v[..]); } #[test] fn test_close_frame() { + let mut buf = BytesMut::new(); let reason = (CloseCode::Normal, "data"); - let frame = Frame::close(Some(reason.into()), false); + Frame::write_close(&mut buf, Some(reason.into()), false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); - assert_eq!(frame.0, v.into()); + assert_eq!(&buf[..], &v[..]); } #[test] fn test_empty_close_frame() { - let frame = Frame::close(None, false); - assert_eq!(frame.0, vec![0x88, 0x00].into()); + let mut buf = BytesMut::new(); + Frame::write_close(&mut buf, None, false); + assert_eq!(&buf[..], &vec![0x88, 0x00][..]); } } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 61fad543c..e8bf3870d 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -1,24 +1,23 @@ -//! `WebSocket` support. +//! WebSocket protocol support. //! //! To setup a `WebSocket`, first do web socket handshake then on success //! convert `Payload` into a `WsStream` stream and then use `WsWriter` to //! communicate with the peer. //! ``` -use bytes::Bytes; -use futures::{Async, Poll, Stream}; -use http::{header, Method, StatusCode}; +use std::io; -use body::Binary; -use error::{PayloadError, ResponseError}; -use payload::PayloadBuffer; +use error::ResponseError; +use http::{header, Method, StatusCode}; use request::Request; use response::{ConnectionType, Response, ResponseBuilder}; +mod codec; mod frame; mod mask; mod proto; -pub use self::frame::{Frame, FramedMessage}; +pub use self::codec::Message; +pub use self::frame::Frame; pub use self::proto::{CloseCode, CloseReason, OpCode}; /// Websocket protocol errors @@ -48,16 +47,16 @@ pub enum ProtocolError { /// Bad utf-8 encoding #[fail(display = "Bad utf-8 encoding.")] BadEncoding, - /// Payload error - #[fail(display = "Payload error: {}", _0)] - Payload(#[cause] PayloadError), + /// Io error + #[fail(display = "io error: {}", _0)] + Io(#[cause] io::Error), } impl ResponseError for ProtocolError {} -impl From for ProtocolError { - fn from(err: PayloadError) -> ProtocolError { - ProtocolError::Payload(err) +impl From for ProtocolError { + fn from(err: io::Error) -> ProtocolError { + ProtocolError::Io(err) } } @@ -109,21 +108,6 @@ impl ResponseError for HandshakeError { } } -/// `WebSocket` Message -#[derive(Debug, PartialEq)] -pub enum Message { - /// Text message - Text(String), - /// Binary message - Binary(Binary), - /// Ping message - Ping(String), - /// Pong message - Pong(String), - /// Close message with optional reason - Close(Option), -} - /// Prepare `WebSocket` handshake response. /// /// This function returns handshake `Response`, ready to send to peer. @@ -189,97 +173,6 @@ pub fn handshake(req: &Request) -> Result { .take()) } -/// Maps `Payload` stream into stream of `ws::Message` items -pub struct WsStream { - rx: PayloadBuffer, - closed: bool, - max_size: usize, -} - -impl WsStream -where - S: Stream, -{ - /// Create new websocket frames stream - pub fn new(stream: S) -> WsStream { - WsStream { - rx: PayloadBuffer::new(stream), - closed: false, - max_size: 65_536, - } - } - - /// Set max frame size - /// - /// By default max size is set to 64kb - pub fn max_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } -} - -impl Stream for WsStream -where - S: Stream, -{ - type Item = Message; - type Error = ProtocolError; - - fn poll(&mut self) -> Poll, Self::Error> { - if self.closed { - return Ok(Async::Ready(None)); - } - - match Frame::parse(&mut self.rx, true, self.max_size) { - Ok(Async::Ready(Some(frame))) => { - let (finished, opcode, payload) = frame.unpack(); - - // continuation is not supported - if !finished { - self.closed = true; - return Err(ProtocolError::NoContinuation); - } - - match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), - OpCode::Bad => { - self.closed = true; - Err(ProtocolError::BadOpCode) - } - OpCode::Close => { - self.closed = true; - let close_reason = Frame::parse_close_payload(&payload); - Ok(Async::Ready(Some(Message::Close(close_reason)))) - } - OpCode::Ping => Ok(Async::Ready(Some(Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into(), - )))), - OpCode::Pong => Ok(Async::Ready(Some(Message::Pong( - String::from_utf8_lossy(payload.as_ref()).into(), - )))), - OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { - self.closed = true; - Err(ProtocolError::BadEncoding) - } - } - } - } - } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - self.closed = true; - Err(e) - } - } - } -} - #[cfg(test)] mod tests { use super::*;