diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 3f81ea9f0..e94c6745e 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,10 +1,7 @@ [package] name = "actix-http" version = "3.9.0" -authors = [ - "Nikolay Kim ", - "Rob Ede ", -] +authors = ["Nikolay Kim ", "Rob Ede "] description = "HTTP types and services for the Actix ecosystem" keywords = ["actix", "http", "framework", "async", "futures"] homepage = "https://actix.rs" @@ -32,6 +29,7 @@ features = [ "compress-brotli", "compress-gzip", "compress-zstd", + "compress-ws-deflate", ] [package.metadata.cargo_check_external_types] @@ -62,12 +60,7 @@ default = [] http2 = ["dep:h2"] # WebSocket protocol implementation -ws = [ - "dep:local-channel", - "dep:base64", - "dep:rand", - "dep:sha1", -] +ws = ["dep:local-channel", "dep:base64", "dep:rand", "dep:sha1"] # TLS via OpenSSL openssl = ["__tls", "actix-tls/accept", "actix-tls/openssl"] @@ -89,8 +82,9 @@ rustls-0_23 = ["__tls", "actix-tls/accept", "actix-tls/rustls-0_23"] # Compression codecs compress-brotli = ["__compress", "dep:brotli"] -compress-gzip = ["__compress", "dep:flate2"] -compress-zstd = ["__compress", "dep:zstd"] +compress-gzip = ["__compress", "dep:flate2"] +compress-zstd = ["__compress", "dep:zstd"] +compress-ws-deflate = ["dep:flate2", "flate2/zlib-default"] # Internal (PRIVATE!) features used to aid testing and checking feature status. # Don't rely on these whatsoever. They are semver-exempt and may disappear at anytime. @@ -112,7 +106,9 @@ bytes = "1" bytestring = "1" derive_more = { version = "1", features = ["as_ref", "deref", "deref_mut", "display", "error", "from"] } encoding_rs = "0.8" -futures-core = { version = "0.3.17", default-features = false, features = ["alloc"] } +futures-core = { version = "0.3.17", default-features = false, features = [ + "alloc", +] } http = "0.2.7" httparse = "1.5.1" httpdate = "1.0.1" @@ -146,14 +142,19 @@ zstd = { version = "0.13", optional = true } [dev-dependencies] actix-http-test = { version = "3", features = ["openssl"] } actix-server = "2" -actix-tls = { version = "3.4", features = ["openssl", "rustls-0_23-webpki-roots"] } +actix-tls = { version = "3.4", features = [ + "openssl", + "rustls-0_23-webpki-roots", +] } actix-web = "4" async-stream = "0.3" criterion = { version = "0.5", features = ["html_reports"] } divan = "0.1.8" env_logger = "0.11" -futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] } +futures-util = { version = "0.3.17", default-features = false, features = [ + "alloc", +] } memchr = "2.4" once_cell = "1.9" rcgen = "0.13" diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index fe2ca43be..6c096eb4b 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -1,9 +1,13 @@ use bitflags::bitflags; use bytes::{Bytes, BytesMut}; use bytestring::ByteString; -use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::codec; use tracing::error; +#[cfg(feature = "compress-ws-deflate")] +use super::deflate::{ + DeflateCompressionContext, DeflateContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG, +}; use super::{ frame::Parser, proto::{CloseReason, OpCode, RsvBits}, @@ -66,16 +70,6 @@ pub enum Item { Last(Bytes), } -/// WebSocket protocol codec. -#[derive(Debug, Clone)] -pub struct Codec { - flags: Flags, - max_size: usize, - - inbound_rsv_bits: Option, - outbound_rsv_bits: RsvBits, -} - bitflags! { #[derive(Debug, Clone, Copy)] struct Flags: u8 { @@ -85,81 +79,116 @@ bitflags! { } } -impl Codec { - /// Create new WebSocket frames decoder. - pub const fn new() -> Codec { - Codec { - max_size: 65_536, +/// WebSocket message encoder. +#[derive(Debug, Clone)] +pub struct Encoder { + flags: Flags, + + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: Option, +} + +impl Encoder { + /// Create new WebSocket frames encoder. + pub const fn new() -> Encoder { + Encoder { flags: Flags::SERVER, - inbound_rsv_bits: None, - outbound_rsv_bits: RsvBits::empty(), + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: None, } } - /// Set max frame size. - /// - /// By default max size is set to 64KiB. - #[must_use = "This returns the a new Codec, without modifying the original."] - pub fn max_size(mut self, size: usize) -> Self { - self.max_size = size; + /// Create new WebSocket frames encoder with `permessage-deflate` extension support. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder { + Encoder { + flags: Flags::SERVER, + + deflate_compress: Some(compress), + } + } + + fn set_client_mode(mut self) -> Self { + self.flags = Flags::empty(); self } - /// Set decoder to client mode. - /// - /// By default decoder works in server mode. - #[must_use = "This returns the a new Codec, without modifying the original."] - pub fn client_mode(mut self) -> Self { - self.flags.remove(Flags::SERVER); + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self.deflate_compress = self + .deflate_compress + .map(|c| c.reset_with(remote_no_context_takeover, remote_max_window_bits)); self } - /// Get inbound RSV bits. - /// - /// Returns None if there's no received frame yet. - pub fn get_inbound_rsv_bits(&self) -> Option { - self.inbound_rsv_bits + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + if let Some(compress) = &mut self.deflate_compress { + Ok((compress.compress(fin, bytes)?, RSV_BIT_DEFLATE_FLAG)) + } else { + Ok((bytes, RsvBits::empty())) + } } - /// Set outbound RSV bits. - pub fn set_outbound_rsv_bits(&mut self, rsv_bits: RsvBits) { - self.outbound_rsv_bits = rsv_bits; + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + Ok((bytes, RsvBits::empty())) } } -impl Default for Codec { +impl Default for Encoder { fn default() -> Self { Self::new() } } -impl Encoder for Codec { +impl codec::Encoder for Encoder { type Error = ProtocolError; fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => Parser::write_message( - dst, - txt, - OpCode::Text, - self.outbound_rsv_bits, - true, - !self.flags.contains(Flags::SERVER), - ), - Message::Binary(bin) => Parser::write_message( - dst, - bin, - OpCode::Binary, - self.outbound_rsv_bits, - true, - !self.flags.contains(Flags::SERVER), - ), + Message::Text(txt) => { + let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?; + + Parser::write_message( + dst, + bytes, + OpCode::Text, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } + Message::Binary(bin) => { + let (bin, rsv_bits) = self.process_payload(true, bin)?; + + Parser::write_message( + dst, + bin, + OpCode::Binary, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } Message::Ping(txt) => Parser::write_message( dst, txt, OpCode::Ping, - self.outbound_rsv_bits, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), @@ -167,14 +196,14 @@ impl Encoder for Codec { dst, txt, OpCode::Pong, - self.outbound_rsv_bits, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), Message::Close(reason) => Parser::write_close( dst, reason, - self.outbound_rsv_bits, + RsvBits::empty(), !self.flags.contains(Flags::SERVER), ), Message::Continuation(cont) => match cont { @@ -182,12 +211,14 @@ impl Encoder for Codec { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Text, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -197,12 +228,14 @@ impl Encoder for Codec { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Binary, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -210,11 +243,13 @@ impl Encoder for Codec { } Item::Continue(data) => { if self.flags.contains(Flags::W_CONTINUATION) { + let (data, rsv_bits) = self.process_payload(false, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -225,11 +260,14 @@ impl Encoder for Codec { Item::Last(data) => { if self.flags.contains(Flags::W_CONTINUATION) { self.flags.remove(Flags::W_CONTINUATION); + + let (data, rsv_bits) = self.process_payload(true, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, - self.outbound_rsv_bits, + rsv_bits, true, !self.flags.contains(Flags::SERVER), ) @@ -244,21 +282,120 @@ impl Encoder for Codec { } } -impl Decoder for Codec { +/// WebSocket message decoder. +#[derive(Debug, Clone)] +pub struct Decoder { + flags: Flags, + max_size: usize, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: Option, +} + +impl Decoder { + /// Create new WebSocket frames decoder. + pub const fn new() -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: None, + } + } + + /// Create new WebSocket frames decoder with `permessage-deflate` extension support. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + deflate_decompress: Some(decompress), + } + } + + fn set_client_mode(mut self) -> Self { + self.flags = Flags::empty(); + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + local_no_context_takeover: bool, + local_max_window_bits: u8, + ) -> Self { + if let Some(decompress) = &mut self.deflate_decompress { + decompress.reset_with(local_no_context_takeover, local_max_window_bits); + } + + self + } + + fn set_max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + opcode: OpCode, + rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + if let Some(bytes) = bytes { + if let Some(decompress) = &mut self.deflate_decompress { + Ok(Some(decompress.decompress(fin, opcode, rsv_bits, bytes)?)) + } else { + Ok(Some(bytes)) + } + } else { + Ok(None) + } + } + + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + _opcode: OpCode, + _rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + Ok(bytes) + } +} + +impl Default for Decoder { + fn default() -> Self { + Self::new() + } +} + +impl codec::Decoder for Decoder { type Item = Frame; type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { Ok(Some((finished, opcode, rsv_bits, payload))) => { - self.inbound_rsv_bits = Some(rsv_bits); + let payload = self.process_payload( + finished, + opcode, + rsv_bits, + payload.map(BytesMut::freeze), + )?; + // continuation is not supported if !finished { return match opcode { OpCode::Continue => { if self.flags.contains(Flags::CONTINUATION) { Ok(Some(Frame::Continuation(Item::Continue( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -268,7 +405,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstBinary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -278,7 +415,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstText( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -296,7 +433,7 @@ impl Decoder for Codec { if self.flags.contains(Flags::CONTINUATION) { self.flags.remove(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::Last( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -311,18 +448,10 @@ impl Decoder for Codec { Ok(Some(Frame::Close(None))) } } - OpCode::Ping => Ok(Some(Frame::Ping( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Pong => Ok(Some(Frame::Pong( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Binary => Ok(Some(Frame::Binary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Text => Ok(Some(Frame::Text( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), + OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))), + OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))), + OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))), + OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))), } } Ok(None) => Ok(None), @@ -330,3 +459,95 @@ impl Decoder for Codec { } } } + +/// WebSocket protocol codec. +#[derive(Debug, Default, Clone)] +pub struct Codec { + encoder: Encoder, + decoder: Decoder, +} + +impl Codec { + /// Create new WebSocket frames codec. + pub fn new() -> Codec { + Codec { + encoder: Encoder::new(), + decoder: Decoder::new(), + } + } + + /// Create new WebSocket frames codec with DEFLATE compression. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(context: DeflateContext) -> Codec { + let DeflateContext { + compress, + decompress, + } = context; + + Codec { + encoder: Encoder::new_deflate(compress), + decoder: Decoder::new_deflate(decompress), + } + } + + /// Set max frame size. + /// + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn max_size(self, size: usize) -> Self { + let Self { encoder, decoder } = self; + + Codec { + encoder, + decoder: decoder.set_max_size(size), + } + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn client_mode(self) -> Self { + let Self { + mut encoder, + mut decoder, + } = self; + + encoder = encoder.set_client_mode(); + decoder = decoder.set_client_mode(); + #[cfg(feature = "compress-ws-deflate")] + { + if let Some(decoder) = &decoder.deflate_decompress { + encoder = encoder.set_client_mode_deflate( + decoder.local_no_context_takeover, + decoder.local_max_window_bits, + ); + } + if let Some(encoder) = &encoder.deflate_compress { + decoder = decoder.set_client_mode_deflate( + encoder.remote_no_context_takeover, + encoder.remote_max_window_bits, + ); + } + } + + Self { encoder, decoder } + } +} + +impl codec::Decoder for Codec { + type Item = Frame; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + self.decoder.decode(src) + } +} + +impl codec::Encoder for Codec { + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.encoder.encode(item, dst) + } +} diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs new file mode 100644 index 000000000..bb256c18f --- /dev/null +++ b/actix-http/src/ws/deflate.rs @@ -0,0 +1,516 @@ +use std::convert::Infallible; + +use bytes::Bytes; +pub use flate2::Compression as DeflateCompressionLevel; + +use super::{OpCode, ProtocolError, RsvBits}; +use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS}; + +const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive = 9..=15; +const DEFAULT_WINDOW_BITS: u8 = 15; +const BUF_SIZE: usize = 2048; + +pub(super) const RSV_BIT_DEFLATE_FLAG: RsvBits = RsvBits::RSV1; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum DeflateHandshakeError { + UnknownWebSocketParameters, + DuplicateParameter(&'static str), + MaxWindowBitsOutOfRange, + NoSuitableConfigurationFound, +} + +impl std::fmt::Display for DeflateHandshakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UnknownWebSocketParameters => { + write!(f, "Unknown WebSocket `permessage-deflate` parameters.") + } + Self::DuplicateParameter(p) => { + write!(f, "Duplicate WebSocket `permessage-deflate` parameter: {p}") + } + Self::MaxWindowBitsOutOfRange => write!( + f, + "Max window bits out of range. ({} to {} expected)", + MAX_WINDOW_BITS_RANGE.start(), + MAX_WINDOW_BITS_RANGE.end() + ), + Self::NoSuitableConfigurationFound => write!( + f, + "No suitable WebSocket `permedia-deflate` parameter configurations found." + ), + } + } +} + +impl std::error::Error for DeflateHandshakeError {} + +#[derive(Copy, Clone, Debug)] +pub enum ClientMaxWindowBits { + NotSpecified, + Specified(u8), +} + +#[derive(Debug, Clone, Default)] +pub struct DeflateSessionParameters { + pub server_no_context_takeover: bool, + pub client_no_context_takeover: bool, + pub server_max_window_bits: Option, + pub client_max_window_bits: Option, +} + +impl TryIntoHeaderPair for DeflateSessionParameters { + type Error = Infallible; + + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let mut response_extension = vec!["permessage-deflate".to_owned()]; + + if self.server_no_context_takeover { + response_extension.push("server_no_context_takeover".to_owned()); + } + if self.client_no_context_takeover { + response_extension.push("client_no_context_takeover".to_owned()); + } + if let Some(server_max_window_bits) = self.server_max_window_bits { + response_extension.push(format!("server_max_window_bits={server_max_window_bits}")); + } + if let Some(client_max_window_bits) = self.client_max_window_bits { + match client_max_window_bits { + ClientMaxWindowBits::NotSpecified => { + response_extension.push("client_max_window_bits".to_string()); + } + ClientMaxWindowBits::Specified(bits) => { + response_extension.push(format!("client_max_window_bits={bits}")); + } + } + } + + Ok(( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_extension.join("; ")).unwrap(), + )) + } +} + +impl DeflateSessionParameters { + fn parse<'a>( + extension_frags: impl Iterator, + ) -> Result { + let mut client_max_window_bits = None; + let mut server_max_window_bits = None; + let mut client_no_context_takeover = None; + let mut server_no_context_takeover = None; + + let mut unknown_parameters = vec![]; + + for fragment in extension_frags { + if fragment == "client_max_window_bits" { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + client_max_window_bits = Some(ClientMaxWindowBits::NotSpecified); + } else if let Some(value) = fragment.strip_prefix("client_max_window_bits=") { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + client_max_window_bits = Some(ClientMaxWindowBits::Specified(bits)); + } else if let Some(value) = fragment.strip_prefix("server_max_window_bits=") { + if server_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + server_max_window_bits = Some(bits); + } else if fragment == "server_no_context_takeover" { + if server_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_no_context_takeover", + )); + } + server_no_context_takeover = Some(true); + } else if fragment == "client_no_context_takeover" { + if client_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_no_context_takeover", + )); + } + client_no_context_takeover = Some(true); + } else { + unknown_parameters.push(fragment.to_owned()); + } + } + + if !unknown_parameters.is_empty() { + Err(DeflateHandshakeError::UnknownWebSocketParameters) + } else { + Ok(DeflateSessionParameters { + server_no_context_takeover: server_no_context_takeover.unwrap_or(false), + client_no_context_takeover: client_no_context_takeover.unwrap_or(false), + server_max_window_bits, + client_max_window_bits, + }) + } + } + + pub fn from_extension_header(header_value: &str) -> Vec> { + let mut results = vec![]; + for extension in header_value.split(',').map(str::trim) { + let mut fragments = extension.split(';').map(str::trim); + if fragments.next() == Some("permessage-deflate") { + results.push(Self::parse(fragments)); + } + } + + results + } + + pub fn create_context( + &self, + compression_level: Option, + is_client_mode: bool, + ) -> DeflateContext { + let client_max_window_bits = + if let Some(ClientMaxWindowBits::Specified(value)) = self.client_max_window_bits { + value + } else { + DEFAULT_WINDOW_BITS + }; + let server_max_window_bits = self.server_max_window_bits.unwrap_or(DEFAULT_WINDOW_BITS); + + let (remote_no_context_takeover, remote_max_window_bits) = if is_client_mode { + (self.server_no_context_takeover, server_max_window_bits) + } else { + (self.client_no_context_takeover, client_max_window_bits) + }; + + let (local_no_context_takeover, local_max_window_bits) = if is_client_mode { + (self.client_no_context_takeover, client_max_window_bits) + } else { + (self.server_no_context_takeover, server_max_window_bits) + }; + + DeflateContext { + compress: DeflateCompressionContext::new( + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + ), + decompress: DeflateDecompressionContext::new( + local_no_context_takeover, + local_max_window_bits, + ), + } + } +} + +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct DeflateServerConfig { + pub compression_level: Option, + + pub server_no_context_takeover: bool, + pub client_no_context_takeover: bool, + pub server_max_window_bits: Option, + pub client_max_window_bits: Option, +} + +impl DeflateServerConfig { + pub fn negotiate(&self, params: DeflateSessionParameters) -> DeflateSessionParameters { + let server_no_context_takeover = + if self.server_no_context_takeover && !params.server_no_context_takeover { + true + } else { + params.server_no_context_takeover + }; + + let client_no_context_takeover = + if self.client_no_context_takeover && !params.client_no_context_takeover { + true + } else { + params.client_no_context_takeover + }; + + let server_max_window_bits = + match (self.server_max_window_bits, params.server_max_window_bits) { + (None, value) => value, + (Some(config_value), None) => Some(config_value), + (Some(config_value), Some(value)) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + let client_max_window_bits = + match (self.client_max_window_bits, params.client_max_window_bits) { + (None, None | Some(ClientMaxWindowBits::NotSpecified)) => None, + (None, Some(ClientMaxWindowBits::Specified(value))) => Some(value), + (Some(_), None) => None, + (Some(config_value), Some(ClientMaxWindowBits::NotSpecified)) => Some(config_value), + (Some(config_value), Some(ClientMaxWindowBits::Specified(value))) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + DeflateSessionParameters { + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits: client_max_window_bits.map(ClientMaxWindowBits::Specified), + } + } +} + +#[derive(Debug)] +pub struct DeflateDecompressionContext { + pub(super) local_no_context_takeover: bool, + pub(super) local_max_window_bits: u8, + + decompress: flate2::Decompress, + + decode_continuation: bool, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl Clone for DeflateDecompressionContext { + fn clone(&self) -> Self { + // Create with empty context because the context is not meant to be cloned. + Self::new(self.local_no_context_takeover, self.local_max_window_bits) + } +} + +impl DeflateDecompressionContext { + fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self { + Self { + local_no_context_takeover, + local_max_window_bits, + + decompress: flate2::Decompress::new_with_window_bits(false, local_max_window_bits), + + decode_continuation: false, + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub fn reset_with(&mut self, local_no_context_takeover: bool, local_max_window_bits: u8) { + *self = Self::new(local_no_context_takeover, local_max_window_bits); + } + + pub fn decompress( + &mut self, + fin: bool, + opcode: OpCode, + rsv: RsvBits, + payload: Bytes, + ) -> Result { + if !matches!(opcode, OpCode::Text | OpCode::Binary | OpCode::Continue) + || !rsv.contains(RSV_BIT_DEFLATE_FLAG) + { + return Ok(payload); + } + + if opcode == OpCode::Continue { + if !self.decode_continuation { + return Ok(payload); + } + } else { + self.decode_continuation = true; + } + + let mut output: Vec = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + let mut offset: usize = 0; + loop { + let res = if offset >= payload.len() { + self.decompress + .decompress( + &[0x00, 0x00, 0xff, 0xff], + &mut buf, + flate2::FlushDecompress::Finish, + ) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + } else { + self.decompress + .decompress(&payload[offset..], &mut buf, flate2::FlushDecompress::None) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + }; + + let read = self.decompress.total_in() - self.total_bytes_read; + let written = self.decompress.total_out() - self.total_bytes_written; + + offset += read as usize; + self.total_bytes_read += read; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + + if fin { + self.decode_continuation = false; + if self.local_no_context_takeover { + self.reset(); + } + } + + Ok(output.into()) + } + + pub(super) fn reset(&mut self) { + self.decompress.reset(false); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +#[derive(Debug)] +pub struct DeflateCompressionContext { + compression_level: flate2::Compression, + pub(super) remote_no_context_takeover: bool, + pub(super) remote_max_window_bits: u8, + + compress: flate2::Compress, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl Clone for DeflateCompressionContext { + fn clone(&self) -> Self { + // Create with empty context because the context is not meant to be cloned. + Self::new( + Some(self.compression_level), + self.remote_no_context_takeover, + self.remote_max_window_bits, + ) + } +} + +impl DeflateCompressionContext { + fn new( + compression_level: Option, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + let compression_level = compression_level.unwrap_or_default(); + + Self { + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + + compress: flate2::Compress::new_with_window_bits( + compression_level, + false, + remote_max_window_bits, + ), + + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub fn reset_with( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self = Self::new( + Some(self.compression_level), + remote_no_context_takeover, + remote_max_window_bits, + ); + + self + } + + pub fn compress(&mut self, fin: bool, payload: Bytes) -> Result { + let mut output = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + loop { + let total_in = self.compress.total_in() - self.total_bytes_read; + let res = if total_in >= payload.len() as u64 { + self.compress + .compress(&[], &mut buf, flate2::FlushCompress::Sync) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + } else { + self.compress + .compress(&payload, &mut buf, flate2::FlushCompress::None) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + }; + + let written = self.compress.total_out() - self.total_bytes_written; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + self.total_bytes_read = self.compress.total_in(); + + if output.iter().rev().take(4).eq(&[0xff, 0xff, 0x00, 0x00]) { + output.drain(output.len() - 4..); + } + + if fin && self.remote_no_context_takeover { + self.reset(); + } + + Ok(output.into()) + } + + fn reset(&mut self) { + self.compress.reset(); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +#[derive(Debug)] +pub struct DeflateContext { + pub compress: DeflateCompressionContext, + pub decompress: DeflateDecompressionContext, +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index e166f1cf5..0bd64a465 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -237,18 +237,22 @@ mod tests { struct F { finished: bool, opcode: OpCode, + rsv_bits: RsvBits, payload: Bytes, } - fn is_none(frm: &Result)>, ProtocolError>) -> bool { + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { matches!(*frm, Ok(None)) } - fn extract(frm: Result)>, ProtocolError>) -> F { + fn extract(frm: Result)>, ProtocolError>) -> F { match frm { - Ok(Some((finished, opcode, payload))) => F { + Ok(Some((finished, opcode, rsv_bits, payload))) => F { finished, opcode, + rsv_bits, payload: payload .map(|b| b.freeze()) .unwrap_or_else(|| Bytes::from("")), @@ -269,6 +273,17 @@ mod tests { assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1"[..]); + + let mut buf = BytesMut::from(&[0b1111_0001u8, 0b0000_0001u8][..]); + buf.extend(b"2"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"2"[..]); + assert!(frame.rsv_bits.contains(RsvBits::RSV1)); + assert!(frame.rsv_bits.contains(RsvBits::RSV2)); + assert!(frame.rsv_bits.contains(RsvBits::RSV3)); } #[test] @@ -377,7 +392,14 @@ mod tests { #[test] fn test_ping_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Ping, + RsvBits::empty(), + true, + false, + ); let mut v = vec![137u8, 4u8]; v.extend(b"data"); @@ -387,7 +409,14 @@ mod tests { #[test] fn test_pong_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Pong, + RsvBits::empty(), + true, + false, + ); let mut v = vec![138u8, 4u8]; v.extend(b"data"); @@ -398,7 +427,7 @@ mod tests { fn test_close_frame() { let mut buf = BytesMut::new(); let reason = (CloseCode::Normal, "data"); - Parser::write_close(&mut buf, Some(reason.into()), false); + Parser::write_close(&mut buf, Some(reason.into()), RsvBits::empty(), false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); @@ -408,7 +437,7 @@ mod tests { #[test] fn test_empty_close_frame() { let mut buf = BytesMut::new(); - Parser::write_close(&mut buf, None, false); + Parser::write_close(&mut buf, None, RsvBits::empty(), false); assert_eq!(&buf[..], &vec![0x88, 0x00][..]); } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 811e63474..90ef4a932 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -11,11 +11,15 @@ use http::{header, Method, StatusCode}; use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder}; mod codec; +#[cfg(feature = "compress-ws-deflate")] +mod deflate; mod dispatcher; mod frame; mod mask; mod proto; +#[cfg(feature = "compress-ws-deflate")] +pub use self::deflate::{DeflateCompressionLevel, DeflateServerConfig, DeflateSessionParameters}; pub use self::{ codec::{Codec, Frame, Item, Message}, dispatcher::Dispatcher, @@ -93,6 +97,11 @@ pub enum HandshakeError { /// WebSocket key is not set or wrong. #[display("unknown WebSocket key")] BadWebsocketKey, + + /// Invalid `permessage-deflate` request. + #[cfg(feature = "compress-ws-deflate")] + #[display(fmt = "invalid WebSocket `permessage-deflate` extension request")] + BadDeflateRequest(deflate::DeflateHandshakeError), } impl From for Response { @@ -135,6 +144,13 @@ impl From for Response { res.head_mut().reason = Some("Handshake error"); res } + + #[cfg(feature = "compress-ws-deflate")] + HandshakeError::BadDeflateRequest(_) => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("Invalid permessage-deflate request"); + res + } } } } @@ -151,6 +167,60 @@ pub fn handshake(req: &RequestHead) -> Result { Ok(handshake_response(req)) } +/// Verify WebSocket handshake request with DEFLATE compression configurations. +#[cfg(feature = "compress-ws-deflate")] +pub fn handshake_with_deflate( + req: &RequestHead, + config: &deflate::DeflateServerConfig, +) -> Result<(ResponseBuilder, Option), HandshakeError> { + verify_handshake(req)?; + + let mut available_configurations = vec![]; + for header in req.headers().get_all(header::SEC_WEBSOCKET_EXTENSIONS) { + let Ok(header_str) = header.to_str() else { + continue; + }; + + available_configurations.extend(deflate::DeflateSessionParameters::from_extension_header( + header_str, + )); + } + + let mut selected_config = None; + let mut selected_error = None; + for config in available_configurations { + match config { + Ok(v) => { + selected_config = Some(v); + break; + } + Err(e) => { + if selected_error.is_none() { + selected_error = Some(e); + } else { + selected_error = + Some(deflate::DeflateHandshakeError::NoSuitableConfigurationFound); + } + } + } + } + + if let Some(selected_error) = selected_error { + Err(HandshakeError::BadDeflateRequest(selected_error)) + } else { + let mut response = handshake_response(req); + + if let Some(selected_config) = selected_config { + let param = config.negotiate(selected_config); + let context = param.create_context(config.compression_level, false); + response.insert_header(param); + Ok((response, Some(context))) + } else { + Ok((response, None)) + } + } +} + /// Verify WebSocket handshake request. pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { // WebSocket accepts only GET @@ -196,6 +266,7 @@ pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } + Ok(()) } diff --git a/awc/Cargo.toml b/awc/Cargo.toml index c09f32ac8..95db5e761 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -63,9 +63,15 @@ rustls-0_20 = ["tls-rustls-0_20", "actix-tls/rustls-0_20"] # TLS via Rustls v0.21 rustls-0_21 = ["tls-rustls-0_21", "actix-tls/rustls-0_21"] # TLS via Rustls v0.22 (WebPKI roots) -rustls-0_22-webpki-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-webpki-roots"] +rustls-0_22-webpki-roots = [ + "tls-rustls-0_22", + "actix-tls/rustls-0_22-webpki-roots", +] # TLS via Rustls v0.22 (Native roots) -rustls-0_22-native-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-native-roots"] +rustls-0_22-native-roots = [ + "tls-rustls-0_22", + "actix-tls/rustls-0_22-native-roots", +] # TLS via Rustls v0.23 rustls-0_23 = ["tls-rustls-0_23", "actix-tls/rustls-0_23"] # TLS via Rustls v0.23 (WebPKI roots) @@ -79,6 +85,8 @@ compress-brotli = ["actix-http/compress-brotli", "__compress"] compress-gzip = ["actix-http/compress-gzip", "__compress"] # Zstd algorithm content-encoding support compress-zstd = ["actix-http/compress-zstd", "__compress"] +# Deflate compression for WebSocket +compress-ws-deflate = ["actix-http/compress-ws-deflate"] # Cookie parsing and cookie jar cookies = ["dep:cookie"] @@ -112,7 +120,7 @@ futures-util = { version = "0.3.17", default-features = false, features = ["allo h2 = "0.3.26" http = "0.2.7" itoa = "1" -log =" 0.4" +log = " 0.4" mime = "0.3" percent-encoding = "2.1" pin-project-lite = "0.2" @@ -125,8 +133,12 @@ tokio = { version = "1.24.2", features = ["sync"] } cookie = { version = "0.16", features = ["percent-encode"], optional = true } tls-openssl = { package = "openssl", version = "0.10.55", optional = true } -tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = ["dangerous_configuration"] } -tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = ["dangerous_configuration"] } +tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = [ + "dangerous_configuration", +] } +tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = [ + "dangerous_configuration", +] } tls-rustls-0_22 = { package = "rustls", version = "0.22", optional = true } tls-rustls-0_23 = { package = "rustls", version = "0.23", optional = true, default-features = false } @@ -151,7 +163,7 @@ rcgen = "0.13" rustls-pemfile = "2" tokio = { version = "1.24.2", features = ["rt-multi-thread", "macros"] } zstd = "0.13" -tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests +tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests [lints] workspace = true diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9d..376dccaa9 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -30,6 +30,8 @@ use std::{fmt, net::SocketAddr, str}; use actix_codec::Framed; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; +#[cfg(feature = "compress-ws-deflate")] +pub use actix_http::ws::{DeflateCompressionLevel, DeflateSessionParameters}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; use actix_service::Service as _; @@ -59,6 +61,9 @@ pub struct WebsocketsRequest { server_mode: bool, config: ClientConfig, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: Option, + #[cfg(feature = "cookies")] cookies: Option, } @@ -94,6 +99,8 @@ impl WebsocketsRequest { protocols: None, max_size: 65_536, server_mode: false, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: None, #[cfg(feature = "cookies")] cookies: None, } @@ -249,6 +256,22 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } + /// Enable DEFLATE compression + #[cfg(feature = "compress-ws-deflate")] + pub fn deflate( + mut self, + compression_level: Option, + params: DeflateSessionParameters, + ) -> Self { + use actix_http::header::TryIntoHeaderPair; + // Assume session parameters are always valid. + let (key, value) = params.try_into_pair().unwrap(); + + self.deflate_compression_level = compression_level; + + self.header(key, value) + } + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, @@ -409,17 +432,51 @@ impl WebsocketsRequest { return Err(WsClientError::MissingWebSocketAcceptHeader); }; - // response and ws framed - Ok(( - ClientResponse::new(head, Payload::None), - framed.into_map_codec(|_| { - if server_mode { - ws::Codec::new().max_size(max_size) + #[cfg(feature = "compress-ws-deflate")] + let framed = { + let selected_parameter = head + .headers + .get_all(header::SEC_WEBSOCKET_EXTENSIONS) + .filter_map(|header| { + if let Ok(header_str) = header.to_str() { + Some(DeflateSessionParameters::from_extension_header(header_str)) + } else { + None + } + }) + .flatten() + .filter_map(Result::ok) + .next(); + + framed.into_map_codec(move |_| { + let codec = if let Some(parameter) = selected_parameter.clone() { + let context = parameter.create_context(self.deflate_compression_level, false); + Codec::new_deflate(context) } else { - ws::Codec::new().max_size(max_size).client_mode() + Codec::new() } - }), - )) + .max_size(max_size); + + if server_mode { + codec + } else { + codec.client_mode() + } + }) + }; + #[cfg(not(feature = "compress-ws-deflate"))] + let framed = framed.into_map_codec(move |_| { + let codec = Codec::new().max_size(max_size); + + if server_mode { + codec + } else { + codec.client_mode() + } + }); + + // response and ws framed + Ok((ClientResponse::new(head, Payload::None), framed)) } }