1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-27 17:52:56 +01:00

refactor ws codec

This commit is contained in:
Nikolay Kim 2018-10-10 13:20:00 -07:00
parent 4a167dc89e
commit 47b47af01a
5 changed files with 113 additions and 67 deletions

View File

@ -1,7 +1,7 @@
use bytes::BytesMut; use bytes::BytesMut;
use tokio_codec::{Decoder, Encoder}; use tokio_codec::{Decoder, Encoder};
use super::frame::Frame; use super::frame::Parser;
use super::proto::{CloseReason, OpCode}; use super::proto::{CloseReason, OpCode};
use super::ProtocolError; use super::ProtocolError;
use body::Binary; use body::Binary;
@ -21,6 +21,21 @@ pub enum Message {
Close(Option<CloseReason>), Close(Option<CloseReason>),
} }
/// `WebSocket` frame
#[derive(Debug, PartialEq)]
pub enum Frame {
/// Text frame, codec does not verify utf8 encoding
Text(Option<BytesMut>),
/// Binary frame
Binary(Option<BytesMut>),
/// Ping message
Ping(String),
/// Pong message
Pong(String),
/// Close message with optional reason
Close(Option<CloseReason>),
}
/// WebSockets protocol codec /// WebSockets protocol codec
pub struct Codec { pub struct Codec {
max_size: usize, max_size: usize,
@ -60,29 +75,29 @@ impl Encoder for Codec {
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item { match item {
Message::Text(txt) => { Message::Text(txt) => {
Frame::write_message(dst, txt, OpCode::Text, true, !self.server) Parser::write_message(dst, txt, OpCode::Text, true, !self.server)
} }
Message::Binary(bin) => { Message::Binary(bin) => {
Frame::write_message(dst, bin, OpCode::Binary, true, !self.server) Parser::write_message(dst, bin, OpCode::Binary, true, !self.server)
} }
Message::Ping(txt) => { Message::Ping(txt) => {
Frame::write_message(dst, txt, OpCode::Ping, true, !self.server) Parser::write_message(dst, txt, OpCode::Ping, true, !self.server)
} }
Message::Pong(txt) => { Message::Pong(txt) => {
Frame::write_message(dst, txt, OpCode::Pong, true, !self.server) Parser::write_message(dst, txt, OpCode::Pong, true, !self.server)
} }
Message::Close(reason) => Frame::write_close(dst, reason, !self.server), Message::Close(reason) => Parser::write_close(dst, reason, !self.server),
} }
Ok(()) Ok(())
} }
} }
impl Decoder for Codec { impl Decoder for Codec {
type Item = Message; type Item = Frame;
type Error = ProtocolError; type Error = ProtocolError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Frame::parse(src, self.server, self.max_size) { match Parser::parse(src, self.server, self.max_size) {
Ok(Some((finished, opcode, payload))) => { Ok(Some((finished, opcode, payload))) => {
// continuation is not supported // continuation is not supported
if !finished { if !finished {
@ -93,23 +108,36 @@ impl Decoder for Codec {
OpCode::Continue => Err(ProtocolError::NoContinuation), OpCode::Continue => Err(ProtocolError::NoContinuation),
OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Bad => Err(ProtocolError::BadOpCode),
OpCode::Close => { OpCode::Close => {
let close_reason = Frame::parse_close_payload(&payload); if let Some(ref pl) = payload {
Ok(Some(Message::Close(close_reason))) let close_reason = Parser::parse_close_payload(pl);
} Ok(Some(Frame::Close(close_reason)))
OpCode::Ping => Ok(Some(Message::Ping( } else {
String::from_utf8_lossy(payload.as_ref()).into(), Ok(Some(Frame::Close(None)))
))),
OpCode::Pong => Ok(Some(Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into(),
))),
OpCode::Binary => Ok(Some(Message::Binary(payload))),
OpCode::Text => {
let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) => Ok(Some(Message::Text(s))),
Err(_) => Err(ProtocolError::BadEncoding),
} }
} }
OpCode::Ping => {
if let Some(ref pl) = payload {
Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into())))
} else {
Ok(Some(Frame::Ping(String::new())))
}
}
OpCode::Pong => {
if let Some(ref pl) = payload {
Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into())))
} else {
Ok(Some(Frame::Pong(String::new())))
}
}
OpCode::Binary => Ok(Some(Frame::Binary(payload))),
OpCode::Text => {
Ok(Some(Frame::Text(payload)))
//let tmp = Vec::from(payload.as_ref());
//match String::from_utf8(tmp) {
// Ok(s) => Ok(Some(Message::Text(s))),
// Err(_) => Err(ProtocolError::BadEncoding),
//}
}
} }
} }
Ok(None) => Ok(None), Ok(None) => Ok(None),

View File

@ -9,9 +9,9 @@ use ws::ProtocolError;
/// A struct representing a `WebSocket` frame. /// A struct representing a `WebSocket` frame.
#[derive(Debug)] #[derive(Debug)]
pub struct Frame; pub struct Parser;
impl Frame { impl Parser {
fn parse_metadata( fn parse_metadata(
src: &[u8], server: bool, max_size: usize, src: &[u8], server: bool, max_size: usize,
) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> { ) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
@ -87,10 +87,10 @@ impl Frame {
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
pub fn parse( pub fn parse(
src: &mut BytesMut, server: bool, max_size: usize, src: &mut BytesMut, server: bool, max_size: usize,
) -> Result<Option<(bool, OpCode, Binary)>, ProtocolError> { ) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
// try to parse ws frame metadata // try to parse ws frame metadata
let (idx, finished, opcode, length, mask) = let (idx, finished, opcode, length, mask) =
match Frame::parse_metadata(src, server, max_size)? { match Parser::parse_metadata(src, server, max_size)? {
None => return Ok(None), None => return Ok(None),
Some(res) => res, Some(res) => res,
}; };
@ -105,7 +105,7 @@ impl Frame {
// no need for body // no need for body
if length == 0 { if length == 0 {
return Ok(Some((finished, opcode, Binary::from("")))); return Ok(Some((finished, opcode, None)));
} }
let mut data = src.split_to(length); let mut data = src.split_to(length);
@ -117,7 +117,7 @@ impl Frame {
} }
OpCode::Close if length > 125 => { OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Some((true, OpCode::Close, Binary::from("")))); return Ok(Some((true, OpCode::Close, None)));
} }
_ => (), _ => (),
} }
@ -127,16 +127,16 @@ impl Frame {
apply_mask(&mut data, mask); apply_mask(&mut data, mask);
} }
Ok(Some((finished, opcode, data.into()))) Ok(Some((finished, opcode, Some(data))))
} }
/// Parse the payload of a close frame. /// Parse the payload of a close frame.
pub fn parse_close_payload(payload: &Binary) -> Option<CloseReason> { pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
if payload.len() >= 2 { if payload.len() >= 2 {
let raw_code = NetworkEndian::read_u16(payload.as_ref()); let raw_code = NetworkEndian::read_u16(payload);
let code = CloseCode::from(raw_code); let code = CloseCode::from(raw_code);
let description = if payload.len() > 2 { let description = if payload.len() > 2 {
Some(String::from_utf8_lossy(&payload.as_ref()[2..]).into()) Some(String::from_utf8_lossy(&payload[2..]).into())
} else { } else {
None None
}; };
@ -203,33 +203,40 @@ impl Frame {
} }
}; };
Frame::write_message(dst, payload, OpCode::Close, true, mask) Parser::write_message(dst, payload, OpCode::Close, true, mask)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use bytes::Bytes;
struct F { struct F {
finished: bool, finished: bool,
opcode: OpCode, opcode: OpCode,
payload: Binary, payload: Bytes,
} }
fn is_none(frm: &Result<Option<(bool, OpCode, Binary)>, ProtocolError>) -> bool { fn is_none(
frm: &Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
) -> bool {
match *frm { match *frm {
Ok(None) => true, Ok(None) => true,
_ => false, _ => false,
} }
} }
fn extract(frm: Result<Option<(bool, OpCode, Binary)>, ProtocolError>) -> F { fn extract(
frm: Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
) -> F {
match frm { match frm {
Ok(Some((finished, opcode, payload))) => F { Ok(Some((finished, opcode, payload))) => F {
finished, finished,
opcode, opcode,
payload, payload: payload
.map(|b| b.freeze())
.unwrap_or_else(|| Bytes::from("")),
}, },
_ => unreachable!("error"), _ => unreachable!("error"),
} }
@ -238,12 +245,12 @@ mod tests {
#[test] #[test]
fn test_parse() { fn test_parse() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
buf.extend(b"1"); buf.extend(b"1");
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Parser::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), &b"1"[..]); assert_eq!(frame.payload.as_ref(), &b"1"[..]);
@ -252,7 +259,7 @@ mod tests {
#[test] #[test]
fn test_parse_length0() { fn test_parse_length0() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Parser::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert!(frame.payload.is_empty()); assert!(frame.payload.is_empty());
@ -261,13 +268,13 @@ mod tests {
#[test] #[test]
fn test_parse_length2() { fn test_parse_length2() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
buf.extend(&[0u8, 4u8][..]); buf.extend(&[0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Parser::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
@ -276,13 +283,13 @@ mod tests {
#[test] #[test]
fn test_parse_length4() { fn test_parse_length4() {
let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Parser::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
buf.extend(b"1234"); buf.extend(b"1234");
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Parser::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), &b"1234"[..]); assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
@ -294,12 +301,12 @@ mod tests {
buf.extend(b"0001"); buf.extend(b"0001");
buf.extend(b"1"); buf.extend(b"1");
assert!(Frame::parse(&mut buf, false, 1024).is_err()); assert!(Parser::parse(&mut buf, false, 1024).is_err());
let frame = extract(Frame::parse(&mut buf, true, 1024)); let frame = extract(Parser::parse(&mut buf, true, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, vec![1u8].into()); assert_eq!(frame.payload, Bytes::from(vec![1u8]));
} }
#[test] #[test]
@ -307,12 +314,12 @@ mod tests {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
buf.extend(&[1u8]); buf.extend(&[1u8]);
assert!(Frame::parse(&mut buf, true, 1024).is_err()); assert!(Parser::parse(&mut buf, true, 1024).is_err());
let frame = extract(Frame::parse(&mut buf, false, 1024)); let frame = extract(Parser::parse(&mut buf, false, 1024));
assert!(!frame.finished); assert!(!frame.finished);
assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload, vec![1u8].into()); assert_eq!(frame.payload, Bytes::from(vec![1u8]));
} }
#[test] #[test]
@ -320,9 +327,9 @@ mod tests {
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
buf.extend(&[1u8, 1u8]); buf.extend(&[1u8, 1u8]);
assert!(Frame::parse(&mut buf, true, 1).is_err()); assert!(Parser::parse(&mut buf, true, 1).is_err());
if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) { if let Err(ProtocolError::Overflow) = Parser::parse(&mut buf, false, 0) {
} else { } else {
unreachable!("error"); unreachable!("error");
} }
@ -331,7 +338,7 @@ mod tests {
#[test] #[test]
fn test_ping_frame() { fn test_ping_frame() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
Frame::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);
let mut v = vec![137u8, 4u8]; let mut v = vec![137u8, 4u8];
v.extend(b"data"); v.extend(b"data");
@ -341,7 +348,7 @@ mod tests {
#[test] #[test]
fn test_pong_frame() { fn test_pong_frame() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
Frame::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);
let mut v = vec![138u8, 4u8]; let mut v = vec![138u8, 4u8];
v.extend(b"data"); v.extend(b"data");
@ -352,7 +359,7 @@ mod tests {
fn test_close_frame() { fn test_close_frame() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
let reason = (CloseCode::Normal, "data"); let reason = (CloseCode::Normal, "data");
Frame::write_close(&mut buf, Some(reason.into()), false); Parser::write_close(&mut buf, Some(reason.into()), false);
let mut v = vec![136u8, 6u8, 3u8, 232u8]; let mut v = vec![136u8, 6u8, 3u8, 232u8];
v.extend(b"data"); v.extend(b"data");
@ -362,7 +369,7 @@ mod tests {
#[test] #[test]
fn test_empty_close_frame() { fn test_empty_close_frame() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
Frame::write_close(&mut buf, None, false); Parser::write_close(&mut buf, None, false);
assert_eq!(&buf[..], &vec![0x88, 0x00][..]); assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
} }
} }

View File

@ -16,8 +16,8 @@ mod mask;
mod proto; mod proto;
mod transport; mod transport;
pub use self::codec::{Codec, Message}; pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Frame; pub use self::frame::Parser;
pub use self::proto::{CloseCode, CloseReason, OpCode}; pub use self::proto::{CloseCode, CloseReason, OpCode};
pub use self::transport::Transport; pub use self::transport::Transport;

View File

@ -4,7 +4,7 @@ use actix_net::service::{IntoService, Service};
use futures::{Future, Poll}; use futures::{Future, Poll};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use super::{Codec, Message}; use super::{Codec, Frame, Message};
pub struct Transport<S, T> pub struct Transport<S, T>
where where
@ -17,7 +17,7 @@ where
impl<S, T> Transport<S, T> impl<S, T> Transport<S, T>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Message, Response = Message>, S: Service<Request = Frame, Response = Message>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
@ -37,7 +37,7 @@ where
impl<S, T> Future for Transport<S, T> impl<S, T> Future for Transport<S, T>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Message, Response = Message>, S: Service<Request = Frame, Response = Message>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {

View File

@ -20,12 +20,23 @@ use futures::{Future, Sink, Stream};
use actix_http::{h1, ws, ResponseError, ServiceConfig}; use actix_http::{h1, ws, ResponseError, ServiceConfig};
fn ws_service(req: ws::Message) -> impl Future<Item = ws::Message, Error = io::Error> { fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> {
match req { match req {
ws::Message::Ping(msg) => ok(ws::Message::Pong(msg)), ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)),
ws::Message::Text(text) => ok(ws::Message::Text(text)), ws::Frame::Text(text) => {
ws::Message::Binary(bin) => ok(ws::Message::Binary(bin)), let text = if let Some(pl) = text {
ws::Message::Close(reason) => ok(ws::Message::Close(reason)), String::from_utf8(Vec::from(pl.as_ref())).unwrap()
} else {
String::new()
};
ok(ws::Message::Text(text))
}
ws::Frame::Binary(bin) => ok(ws::Message::Binary(
bin.map(|e| e.freeze())
.unwrap_or_else(|| Bytes::from(""))
.into(),
)),
ws::Frame::Close(reason) => ok(ws::Message::Close(reason)),
_ => ok(ws::Message::Close(None)), _ => ok(ws::Message::Close(None)),
} }
} }