From b4b3350b3e4b857daec140ff60981aaa1d4388ab Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 12 Dec 2019 14:06:54 +0600 Subject: [PATCH] Add websockets continuation frame support --- actix-http/CHANGES.md | 6 + actix-http/Cargo.toml | 15 ++- actix-http/src/ws/codec.rs | 220 +++++++++++++++++++++++++++------- actix-http/src/ws/frame.rs | 6 +- actix-http/src/ws/mod.rs | 17 +-- actix-http/tests/test_ws.rs | 66 +++++++++- actix-multipart/src/server.rs | 7 +- 7 files changed, 269 insertions(+), 68 deletions(-) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 26a670e9a..052b3eff3 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [1.0.0] - 2019-12-xx + +### Added + +* Add websockets continuation frame support + ## [1.0.0-alpha.5] - 2019-12-09 ### Fixed diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 8b9a1cc31..79e8777d6 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http" -version = "1.0.0-alpha.6" +version = "1.0.0" authors = ["Nikolay Kim "] description = "Actix http primitives" readme = "README.md" @@ -13,7 +13,6 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket"] license = "MIT/Apache-2.0" edition = "2018" -workspace = ".." [package.metadata.docs.rs] features = ["openssl", "rustls", "fail", "flate2-zlib", "secure-cookies"] @@ -47,26 +46,26 @@ secure-cookies = ["ring"] actix-service = "1.0.0" actix-codec = "0.2.0" actix-connect = "1.0.0" -actix-utils = "1.0.1" +actix-utils = "1.0.3" actix-rt = "1.0.0" -actix-threadpool = "0.3.0" +actix-threadpool = "0.3.1" actix-tls = { version = "1.0.0", optional = true } base64 = "0.11" -bitflags = "1.0" +bitflags = "1.2" bytes = "0.5.2" copyless = "0.1.4" chrono = "0.4.6" derive_more = "0.99.2" -either = "1.5.2" +either = "1.5.3" encoding_rs = "0.8" futures = "0.3.1" fxhash = "0.2.1" h2 = "0.2.1" http = "0.2.0" httparse = "1.3" -indexmap = "1.2" -lazy_static = "1.0" +indexmap = "1.3" +lazy_static = "1.4" language-tags = "0.2" log = "0.4" mime = "0.3" diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index fee8a632b..a37208a2b 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -12,6 +12,8 @@ pub enum Message { Text(String), /// Binary message Binary(Bytes), + /// Continuation + Continuation(Item), /// Ping message Ping(Bytes), /// Pong message @@ -26,9 +28,11 @@ pub enum Message { #[derive(Debug, PartialEq)] pub enum Frame { /// Text frame, codec does not verify utf8 encoding - Text(Option), + Text(Bytes), /// Binary frame - Binary(Option), + Binary(Bytes), + /// Continuation + Continuation(Item), /// Ping message Ping(Bytes), /// Pong message @@ -37,11 +41,28 @@ pub enum Frame { Close(Option), } +/// `WebSocket` continuation item +#[derive(Debug, PartialEq)] +pub enum Item { + FirstText(Bytes), + FirstBinary(Bytes), + Continue(Bytes), + Last(Bytes), +} + #[derive(Debug, Copy, Clone)] /// WebSockets protocol codec pub struct Codec { + flags: Flags, max_size: usize, - server: bool, +} + +bitflags::bitflags! { + struct Flags: u8 { + const SERVER = 0b0000_0001; + const CONTINUATION = 0b0000_0010; + const W_CONTINUATION = 0b0000_0100; + } } impl Codec { @@ -49,7 +70,7 @@ impl Codec { pub fn new() -> Codec { Codec { max_size: 65_536, - server: true, + flags: Flags::SERVER, } } @@ -65,7 +86,7 @@ impl Codec { /// /// By default decoder works in server mode. pub fn client_mode(mut self) -> Self { - self.server = false; + self.flags.remove(Flags::SERVER); self } } @@ -76,19 +97,94 @@ impl Encoder for Codec { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => { - Parser::write_message(dst, txt, OpCode::Text, true, !self.server) + Message::Text(txt) => Parser::write_message( + dst, + txt, + OpCode::Text, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Binary(bin) => Parser::write_message( + dst, + bin, + OpCode::Binary, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Ping(txt) => Parser::write_message( + dst, + txt, + OpCode::Ping, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Pong(txt) => Parser::write_message( + dst, + txt, + OpCode::Pong, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Close(reason) => { + Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) } - Message::Binary(bin) => { - Parser::write_message(dst, bin, OpCode::Binary, true, !self.server) - } - Message::Ping(txt) => { - Parser::write_message(dst, txt, OpCode::Ping, true, !self.server) - } - Message::Pong(txt) => { - Parser::write_message(dst, txt, OpCode::Pong, true, !self.server) - } - Message::Close(reason) => Parser::write_close(dst, reason, !self.server), + Message::Continuation(cont) => match cont { + Item::FirstText(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + return Err(ProtocolError::ContinuationStarted); + } else { + self.flags.insert(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Binary, + false, + !self.flags.contains(Flags::SERVER), + ) + } + } + Item::FirstBinary(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + return Err(ProtocolError::ContinuationStarted); + } else { + self.flags.insert(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Text, + false, + !self.flags.contains(Flags::SERVER), + ) + } + } + Item::Continue(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + Parser::write_message( + dst, + &data[..], + OpCode::Continue, + false, + !self.flags.contains(Flags::SERVER), + ) + } else { + return Err(ProtocolError::ContinuationNotStarted); + } + } + Item::Last(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + self.flags.remove(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Continue, + true, + !self.flags.contains(Flags::SERVER), + ) + } else { + return Err(ProtocolError::ContinuationNotStarted); + } + } + }, Message::Nop => (), } Ok(()) @@ -100,15 +196,64 @@ impl Decoder for Codec { type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - match Parser::parse(src, self.server, self.max_size) { + match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { Ok(Some((finished, opcode, payload))) => { // continuation is not supported if !finished { - return Err(ProtocolError::NoContinuation); + return match opcode { + OpCode::Continue => { + if self.flags.contains(Flags::CONTINUATION) { + Ok(Some(Frame::Continuation(Item::Continue( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationNotStarted) + } + } + OpCode::Binary => { + if !self.flags.contains(Flags::CONTINUATION) { + self.flags.insert(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::FirstBinary( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationStarted) + } + } + OpCode::Text => { + if !self.flags.contains(Flags::CONTINUATION) { + self.flags.insert(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::FirstText( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationStarted) + } + } + _ => { + error!("Unfinished fragment {:?}", opcode); + Err(ProtocolError::ContinuationFragment(opcode)) + } + }; } match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Continue => { + if self.flags.contains(Flags::CONTINUATION) { + self.flags.remove(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::Last( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationNotStarted) + } + } OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Close => { if let Some(ref pl) = payload { @@ -118,29 +263,18 @@ impl Decoder for Codec { Ok(Some(Frame::Close(None))) } } - OpCode::Ping => { - if let Some(pl) = payload { - Ok(Some(Frame::Ping(pl.freeze()))) - } else { - Ok(Some(Frame::Ping(Bytes::new()))) - } - } - OpCode::Pong => { - if let Some(pl) = payload { - Ok(Some(Frame::Pong(pl.freeze()))) - } else { - Ok(Some(Frame::Pong(Bytes::new()))) - } - } - OpCode::Binary => Ok(Some(Frame::Binary(payload))), - OpCode::Text => { - Ok(Some(Frame::Text(payload))) - //let tmp = Vec::from(payload.as_ref()); - //match String::from_utf8(tmp) { - // Ok(s) => Ok(Some(Message::Text(s))), - // Err(_) => Err(ProtocolError::BadEncoding), - //} - } + OpCode::Ping => Ok(Some(Frame::Ping( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Pong => Ok(Some(Frame::Pong( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Binary => Ok(Some(Frame::Binary( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Text => Ok(Some(Frame::Text( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), } } Ok(None) => Ok(None), diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 0949b711f..a280ff9c7 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -1,6 +1,6 @@ use std::convert::TryFrom; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; use log::debug; use rand; @@ -154,14 +154,14 @@ impl Parser { } /// Generate binary representation - pub fn write_message>( + pub fn write_message>( dst: &mut BytesMut, pl: B, op: OpCode, fin: bool, mask: bool, ) { - let payload = pl.into(); + let payload = pl.as_ref(); let one: u8 = if fin { 0x80 | Into::::into(op) } else { diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index bc48c8e4e..ffa397979 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -18,7 +18,7 @@ mod frame; mod mask; mod proto; -pub use self::codec::{Codec, Frame, Message}; +pub use self::codec::{Codec, Frame, Item, Message}; pub use self::dispatcher::Dispatcher; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; @@ -44,12 +44,15 @@ pub enum ProtocolError { /// A payload reached size limit. #[display(fmt = "A payload reached size limit.")] Overflow, - /// Continuation is not supported - #[display(fmt = "Continuation is not supported.")] - NoContinuation, - /// Bad utf-8 encoding - #[display(fmt = "Bad utf-8 encoding.")] - BadEncoding, + /// Continuation is not started + #[display(fmt = "Continuation is not started.")] + ContinuationNotStarted, + /// Received new continuation but it is already started + #[display(fmt = "Received new continuation but it is already started")] + ContinuationStarted, + /// Unknown continuation fragment + #[display(fmt = "Unknown continuation fragment.")] + ContinuationFragment(OpCode), /// Io error #[display(fmt = "io error: {}", _0)] Io(io::Error), diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 6295ae283..5d70d24ae 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -2,7 +2,7 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; use actix_http_test::TestServer; use actix_utils::framed::Dispatcher; -use bytes::BytesMut; +use bytes::Bytes; use futures::future; use futures::{SinkExt, StreamExt}; @@ -25,9 +25,10 @@ async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + ws::Message::Text(String::from_utf8_lossy(&text).to_string()) } - ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Binary(bin) => ws::Message::Binary(bin), + ws::Frame::Continuation(item) => ws::Message::Continuation(item), ws::Frame::Close(reason) => ws::Message::Close(reason), _ => panic!(), }; @@ -52,7 +53,7 @@ async fn test_simple() { let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), - ws::Frame::Text(Some(BytesMut::from("text"))) + ws::Frame::Text(Bytes::from_static(b"text")) ); framed @@ -62,7 +63,7 @@ async fn test_simple() { let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), - ws::Frame::Binary(Some(BytesMut::from(&b"text"[..]).into())) + ws::Frame::Binary(Bytes::from_static(&b"text"[..])) ); framed.send(ws::Message::Ping("text".into())).await.unwrap(); @@ -72,6 +73,61 @@ async fn test_simple() { ws::Frame::Pong("text".to_string().into()) ); + framed + .send(ws::Message::Continuation(ws::Item::FirstText( + "text".into(), + ))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text"))) + ); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::FirstText( + "text".into() + ))) + .await + .is_err()); + assert!(framed + .send(ws::Message::Continuation(ws::Item::FirstBinary( + "text".into() + ))) + .await + .is_err()); + + framed + .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text"))) + ); + + framed + .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text"))) + ); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .await + .is_err()); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .await + .is_err()); + framed .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .await diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index dca94d7a2..2555cb7a3 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -880,14 +880,17 @@ mod tests { bytes: bytes, pos: 0, ready: false, - } + }; } } impl Stream for SlowStream { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { let this = self.get_mut(); if !this.ready { this.ready = true;