From 47b47af01a6baf8258c1286d3c82f6f28524e2ba Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 10 Oct 2018 13:20:00 -0700 Subject: [PATCH] refactor ws codec --- src/ws/codec.rs | 74 ++++++++++++++++++++++++++++++-------------- src/ws/frame.rs | 75 +++++++++++++++++++++++++-------------------- src/ws/mod.rs | 4 +-- src/ws/transport.rs | 6 ++-- tests/test_ws.rs | 21 ++++++++++--- 5 files changed, 113 insertions(+), 67 deletions(-) diff --git a/src/ws/codec.rs b/src/ws/codec.rs index 6e2b1209..7ba10672 100644 --- a/src/ws/codec.rs +++ b/src/ws/codec.rs @@ -1,7 +1,7 @@ use bytes::BytesMut; use tokio_codec::{Decoder, Encoder}; -use super::frame::Frame; +use super::frame::Parser; use super::proto::{CloseReason, OpCode}; use super::ProtocolError; use body::Binary; @@ -21,6 +21,21 @@ pub enum Message { Close(Option), } +/// `WebSocket` frame +#[derive(Debug, PartialEq)] +pub enum Frame { + /// Text frame, codec does not verify utf8 encoding + Text(Option), + /// Binary frame + Binary(Option), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), +} + /// WebSockets protocol codec pub struct Codec { max_size: usize, @@ -60,29 +75,29 @@ impl Encoder for Codec { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { Message::Text(txt) => { - Frame::write_message(dst, txt, OpCode::Text, true, !self.server) + Parser::write_message(dst, txt, OpCode::Text, true, !self.server) } Message::Binary(bin) => { - Frame::write_message(dst, bin, OpCode::Binary, true, !self.server) + Parser::write_message(dst, bin, OpCode::Binary, true, !self.server) } Message::Ping(txt) => { - Frame::write_message(dst, txt, OpCode::Ping, true, !self.server) + Parser::write_message(dst, txt, OpCode::Ping, true, !self.server) } Message::Pong(txt) => { - Frame::write_message(dst, txt, OpCode::Pong, true, !self.server) + Parser::write_message(dst, txt, OpCode::Pong, true, !self.server) } - Message::Close(reason) => Frame::write_close(dst, reason, !self.server), + Message::Close(reason) => Parser::write_close(dst, reason, !self.server), } Ok(()) } } impl Decoder for Codec { - type Item = Message; + type Item = Frame; type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - match Frame::parse(src, self.server, self.max_size) { + match Parser::parse(src, self.server, self.max_size) { Ok(Some((finished, opcode, payload))) => { // continuation is not supported if !finished { @@ -93,23 +108,36 @@ impl Decoder for Codec { OpCode::Continue => Err(ProtocolError::NoContinuation), OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Close => { - let close_reason = Frame::parse_close_payload(&payload); - Ok(Some(Message::Close(close_reason))) - } - OpCode::Ping => Ok(Some(Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into(), - ))), - OpCode::Pong => Ok(Some(Message::Pong( - String::from_utf8_lossy(payload.as_ref()).into(), - ))), - OpCode::Binary => Ok(Some(Message::Binary(payload))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => Ok(Some(Message::Text(s))), - Err(_) => Err(ProtocolError::BadEncoding), + if let Some(ref pl) = payload { + let close_reason = Parser::parse_close_payload(pl); + Ok(Some(Frame::Close(close_reason))) + } else { + Ok(Some(Frame::Close(None))) } } + OpCode::Ping => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Ping(String::new()))) + } + } + OpCode::Pong => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Pong(String::new()))) + } + } + OpCode::Binary => Ok(Some(Frame::Binary(payload))), + OpCode::Text => { + Ok(Some(Frame::Text(payload))) + //let tmp = Vec::from(payload.as_ref()); + //match String::from_utf8(tmp) { + // Ok(s) => Ok(Some(Message::Text(s))), + // Err(_) => Err(ProtocolError::BadEncoding), + //} + } } } Ok(None) => Ok(None), diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 38bebc28..de1b9239 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -9,9 +9,9 @@ use ws::ProtocolError; /// A struct representing a `WebSocket` frame. #[derive(Debug)] -pub struct Frame; +pub struct Parser; -impl Frame { +impl Parser { fn parse_metadata( src: &[u8], server: bool, max_size: usize, ) -> Result)>, ProtocolError> { @@ -87,10 +87,10 @@ impl Frame { /// Parse the input stream into a frame. pub fn parse( src: &mut BytesMut, server: bool, max_size: usize, - ) -> Result, ProtocolError> { + ) -> Result)>, ProtocolError> { // try to parse ws frame metadata let (idx, finished, opcode, length, mask) = - match Frame::parse_metadata(src, server, max_size)? { + match Parser::parse_metadata(src, server, max_size)? { None => return Ok(None), Some(res) => res, }; @@ -105,7 +105,7 @@ impl Frame { // no need for body if length == 0 { - return Ok(Some((finished, opcode, Binary::from("")))); + return Ok(Some((finished, opcode, None))); } let mut data = src.split_to(length); @@ -117,7 +117,7 @@ impl Frame { } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some((true, OpCode::Close, Binary::from("")))); + return Ok(Some((true, OpCode::Close, None))); } _ => (), } @@ -127,16 +127,16 @@ impl Frame { apply_mask(&mut data, mask); } - Ok(Some((finished, opcode, data.into()))) + Ok(Some((finished, opcode, Some(data)))) } /// Parse the payload of a close frame. - pub fn parse_close_payload(payload: &Binary) -> Option { + pub fn parse_close_payload(payload: &[u8]) -> Option { if payload.len() >= 2 { - let raw_code = NetworkEndian::read_u16(payload.as_ref()); + let raw_code = NetworkEndian::read_u16(payload); let code = CloseCode::from(raw_code); let description = if payload.len() > 2 { - Some(String::from_utf8_lossy(&payload.as_ref()[2..]).into()) + Some(String::from_utf8_lossy(&payload[2..]).into()) } else { None }; @@ -203,33 +203,40 @@ impl Frame { } }; - Frame::write_message(dst, payload, OpCode::Close, true, mask) + Parser::write_message(dst, payload, OpCode::Close, true, mask) } } #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; struct F { finished: bool, opcode: OpCode, - payload: Binary, + payload: Bytes, } - fn is_none(frm: &Result, ProtocolError>) -> bool { + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { match *frm { Ok(None) => true, _ => false, } } - fn extract(frm: Result, ProtocolError>) -> F { + fn extract( + frm: Result)>, ProtocolError>, + ) -> F { match frm { Ok(Some((finished, opcode, payload))) => F { finished, opcode, - payload, + payload: payload + .map(|b| b.freeze()) + .unwrap_or_else(|| Bytes::from("")), }, _ => unreachable!("error"), } @@ -238,12 +245,12 @@ mod tests { #[test] fn test_parse() { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); - assert!(is_none(&Frame::parse(&mut buf, false, 1024))); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); buf.extend(b"1"); - let frame = extract(Frame::parse(&mut buf, false, 1024)); + let frame = extract(Parser::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1"[..]); @@ -252,7 +259,7 @@ mod tests { #[test] fn test_parse_length0() { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); - let frame = extract(Frame::parse(&mut buf, false, 1024)); + let frame = extract(Parser::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert!(frame.payload.is_empty()); @@ -261,13 +268,13 @@ mod tests { #[test] fn test_parse_length2() { let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); - assert!(is_none(&Frame::parse(&mut buf, false, 1024))); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); buf.extend(&[0u8, 4u8][..]); buf.extend(b"1234"); - let frame = extract(Frame::parse(&mut buf, false, 1024)); + let frame = extract(Parser::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -276,13 +283,13 @@ mod tests { #[test] fn test_parse_length4() { let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); - assert!(is_none(&Frame::parse(&mut buf, false, 1024))); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(b"1234"); - let frame = extract(Frame::parse(&mut buf, false, 1024)); + let frame = extract(Parser::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -294,12 +301,12 @@ mod tests { buf.extend(b"0001"); buf.extend(b"1"); - assert!(Frame::parse(&mut buf, false, 1024).is_err()); + assert!(Parser::parse(&mut buf, false, 1024).is_err()); - let frame = extract(Frame::parse(&mut buf, true, 1024)); + let frame = extract(Parser::parse(&mut buf, true, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload, vec![1u8].into()); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); } #[test] @@ -307,12 +314,12 @@ mod tests { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); buf.extend(&[1u8]); - assert!(Frame::parse(&mut buf, true, 1024).is_err()); + assert!(Parser::parse(&mut buf, true, 1024).is_err()); - let frame = extract(Frame::parse(&mut buf, false, 1024)); + let frame = extract(Parser::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload, vec![1u8].into()); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); } #[test] @@ -320,9 +327,9 @@ mod tests { let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); buf.extend(&[1u8, 1u8]); - assert!(Frame::parse(&mut buf, true, 1).is_err()); + assert!(Parser::parse(&mut buf, true, 1).is_err()); - if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) { + if let Err(ProtocolError::Overflow) = Parser::parse(&mut buf, false, 0) { } else { unreachable!("error"); } @@ -331,7 +338,7 @@ mod tests { #[test] fn test_ping_frame() { let mut buf = BytesMut::new(); - Frame::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); let mut v = vec![137u8, 4u8]; v.extend(b"data"); @@ -341,7 +348,7 @@ mod tests { #[test] fn test_pong_frame() { let mut buf = BytesMut::new(); - Frame::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); let mut v = vec![138u8, 4u8]; v.extend(b"data"); @@ -352,7 +359,7 @@ mod tests { fn test_close_frame() { let mut buf = BytesMut::new(); let reason = (CloseCode::Normal, "data"); - Frame::write_close(&mut buf, Some(reason.into()), false); + Parser::write_close(&mut buf, Some(reason.into()), false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); @@ -362,7 +369,7 @@ mod tests { #[test] fn test_empty_close_frame() { let mut buf = BytesMut::new(); - Frame::write_close(&mut buf, None, false); + Parser::write_close(&mut buf, None, false); assert_eq!(&buf[..], &vec![0x88, 0x00][..]); } } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 5ebb502b..7df1f4b4 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -16,8 +16,8 @@ mod mask; mod proto; mod transport; -pub use self::codec::{Codec, Message}; -pub use self::frame::Frame; +pub use self::codec::{Codec, Frame, Message}; +pub use self::frame::Parser; pub use self::proto::{CloseCode, CloseReason, OpCode}; pub use self::transport::Transport; diff --git a/src/ws/transport.rs b/src/ws/transport.rs index aabeb5d5..102d02b4 100644 --- a/src/ws/transport.rs +++ b/src/ws/transport.rs @@ -4,7 +4,7 @@ use actix_net::service::{IntoService, Service}; use futures::{Future, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; -use super::{Codec, Message}; +use super::{Codec, Frame, Message}; pub struct Transport where @@ -17,7 +17,7 @@ where impl Transport where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { @@ -37,7 +37,7 @@ where impl Future for Transport where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 73590990..91e212ef 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -20,12 +20,23 @@ use futures::{Future, Sink, Stream}; use actix_http::{h1, ws, ResponseError, ServiceConfig}; -fn ws_service(req: ws::Message) -> impl Future { +fn ws_service(req: ws::Frame) -> impl Future { match req { - ws::Message::Ping(msg) => ok(ws::Message::Pong(msg)), - ws::Message::Text(text) => ok(ws::Message::Text(text)), - ws::Message::Binary(bin) => ok(ws::Message::Binary(bin)), - ws::Message::Close(reason) => ok(ws::Message::Close(reason)), + ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)), + ws::Frame::Text(text) => { + let text = if let Some(pl) = text { + String::from_utf8(Vec::from(pl.as_ref())).unwrap() + } else { + String::new() + }; + ok(ws::Message::Text(text)) + } + ws::Frame::Binary(bin) => ok(ws::Message::Binary( + bin.map(|e| e.freeze()) + .unwrap_or_else(|| Bytes::from("")) + .into(), + )), + ws::Frame::Close(reason) => ok(ws::Message::Close(reason)), _ => ok(ws::Message::Close(None)), } }