diff --git a/actix-ws/CHANGELOG.md b/actix-ws/CHANGELOG.md index a26894008..fea469d55 100644 --- a/actix-ws/CHANGELOG.md +++ b/actix-ws/CHANGELOG.md @@ -2,10 +2,12 @@ ## Unreleased -- Take the encoded buffer when yielding bytes in the response stream rather than splitting the buffer, reducing memory use +- Add `AggregatedMessage[Stream]` types. +- Add `MessageStream::max_frame_size()` setter method. +- Add `Session::continuation()` method. +- The `Session::text()` method now receives an `impl Into`, making broadcasting text messages more efficient. - Remove type parameters from `Session::{text, binary}()` methods, replacing with equivalent `impl Trait` parameters. -- `Session::text()` now receives an `impl Into`, making broadcasting text messages more efficient. -- Allow sending continuations via `Session::continuation()` +- Reduce memory usage by `take`-ing (rather than `split`-ing) the encoded buffer when yielding bytes in the response stream. ## 0.2.5 diff --git a/actix-ws/examples/chat.rs b/actix-ws/examples/chat.rs index 67c1d1573..9d5c823df 100644 --- a/actix-ws/examples/chat.rs +++ b/actix-ws/examples/chat.rs @@ -7,7 +7,7 @@ use std::{ use actix_web::{ middleware::Logger, web, web::Html, App, HttpRequest, HttpResponse, HttpServer, Responder, }; -use actix_ws::{Message, Session}; +use actix_ws::{AggregatedMessage, Session}; use bytestring::ByteString; use futures_util::{stream::FuturesUnordered, StreamExt as _}; use tokio::sync::Mutex; @@ -65,7 +65,10 @@ async fn ws( body: web::Payload, chat: web::Data, ) -> Result { - let (response, mut session, mut stream) = actix_ws::handle(&req, body)?; + let (response, mut session, stream) = actix_ws::handle(&req, body)?; + + // increase the maximum allowed frame size to 128KiB and aggregate continuation frames + let mut stream = stream.max_frame_size(128 * 1024).aggregate_continuations(); chat.insert(session.clone()).await; tracing::info!("Inserted session"); @@ -91,30 +94,29 @@ async fn ws( }); actix_web::rt::spawn(async move { - while let Some(Ok(msg)) = stream.next().await { + while let Some(Ok(msg)) = stream.recv().await { match msg { - Message::Ping(bytes) => { + AggregatedMessage::Ping(bytes) => { if session.pong(&bytes).await.is_err() { return; } } - Message::Text(msg) => { - tracing::info!("Relaying msg: {msg}"); - chat.send(msg).await; + + AggregatedMessage::Text(string) => { + tracing::info!("Relaying text, {string}"); + chat.send(string).await; } - Message::Close(reason) => { + + AggregatedMessage::Close(reason) => { let _ = session.close(reason).await; tracing::info!("Got close, bailing"); return; } - Message::Continuation(_) => { - let _ = session.close(None).await; - tracing::info!("Got continuation, bailing"); - return; - } - Message::Pong(_) => { + + AggregatedMessage::Pong(_) => { *alive.lock().await = Instant::now(); } + _ => (), }; } diff --git a/actix-ws/src/aggregated.rs b/actix-ws/src/aggregated.rs new file mode 100644 index 000000000..e916e534b --- /dev/null +++ b/actix-ws/src/aggregated.rs @@ -0,0 +1,216 @@ +//! WebSocket stream for aggregating continuation frames. + +use std::{ + future::poll_fn, + io, mem, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use actix_http::ws::{CloseReason, Item, Message, ProtocolError}; +use actix_web::web::{Bytes, BytesMut}; +use bytestring::ByteString; +use futures_core::Stream; + +use crate::MessageStream; + +pub(crate) enum ContinuationKind { + Text, + Binary, +} + +/// WebSocket message with any continuations aggregated together. +#[derive(Debug, PartialEq, Eq)] +pub enum AggregatedMessage { + /// Text message. + Text(ByteString), + + /// Binary message. + Binary(Bytes), + + /// Ping message. + Ping(Bytes), + + /// Pong message. + Pong(Bytes), + + /// Close message with optional reason. + Close(Option), +} + +/// Stream of messages from a WebSocket client, with continuations aggregated. +pub struct AggregatedMessageStream { + stream: MessageStream, + current_size: usize, + max_size: usize, + continuations: Vec, + continuation_kind: ContinuationKind, +} + +impl AggregatedMessageStream { + #[must_use] + pub(crate) fn new(stream: MessageStream) -> Self { + AggregatedMessageStream { + stream, + current_size: 0, + max_size: 1024 * 1024, + continuations: Vec::new(), + continuation_kind: ContinuationKind::Binary, + } + } + + /// Sets the maximum allowed size for aggregated continuations, in bytes. + /// + /// By default, up to 1 MiB is allowed. + /// + /// ```no_run + /// # use actix_ws::AggregatedMessageStream; + /// # async fn test(stream: AggregatedMessageStream) { + /// // increase the allowed size from 1MB to 8MB + /// let mut stream = stream.max_continuation_size(8 * 1024 * 1024); + /// + /// while let Some(Ok(msg)) = stream.recv().await { + /// // handle message + /// } + /// # } + /// ``` + #[must_use] + pub fn max_continuation_size(mut self, max_size: usize) -> Self { + self.max_size = max_size; + self + } + + /// Waits for the next item from the aggregated message stream. + /// + /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation. + /// + /// ```no_run + /// # use actix_ws::AggregatedMessageStream; + /// # async fn test(mut stream: AggregatedMessageStream) { + /// while let Some(Ok(msg)) = stream.recv().await { + /// // handle message + /// } + /// # } + /// ``` + #[must_use] + pub async fn recv(&mut self) -> Option<::Item> { + poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await + } +} + +fn size_error() -> Poll>> { + Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other( + "Exceeded maximum continuation size", + ))))) +} + +impl Stream for AggregatedMessageStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else { + return Poll::Ready(None); + }; + + match msg { + Message::Continuation(item) => match item { + Item::FirstText(bytes) => { + this.continuation_kind = ContinuationKind::Text; + this.current_size += bytes.len(); + + if this.current_size > this.max_size { + this.continuations.clear(); + return size_error(); + } + + this.continuations.push(bytes); + + Poll::Pending + } + + Item::FirstBinary(bytes) => { + this.continuation_kind = ContinuationKind::Binary; + this.current_size += bytes.len(); + + if this.current_size > this.max_size { + this.continuations.clear(); + return size_error(); + } + + this.continuations.push(bytes); + + Poll::Pending + } + + Item::Continue(bytes) => { + this.current_size += bytes.len(); + + if this.current_size > this.max_size { + this.continuations.clear(); + return size_error(); + } + + this.continuations.push(bytes); + + Poll::Pending + } + + Item::Last(bytes) => { + this.current_size += bytes.len(); + + if this.current_size > this.max_size { + // reset current_size, as this is the last message for + // the current continuation + this.current_size = 0; + this.continuations.clear(); + + return size_error(); + } + + this.continuations.push(bytes); + let bytes = collect(&mut this.continuations); + + this.current_size = 0; + + match this.continuation_kind { + ContinuationKind::Text => { + Poll::Ready(Some(match ByteString::try_from(bytes) { + Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)), + Err(err) => Err(ProtocolError::Io(io::Error::new( + io::ErrorKind::InvalidData, + err.to_string(), + ))), + })) + } + ContinuationKind::Binary => { + Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes)))) + } + } + } + }, + + Message::Text(text) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))), + Message::Binary(binary) => Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary)))), + Message::Ping(ping) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))), + Message::Pong(pong) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))), + Message::Close(close) => Poll::Ready(Some(Ok(AggregatedMessage::Close(close)))), + + Message::Nop => unreachable!("MessageStream should not produce no-ops"), + } + } +} + +fn collect(continuations: &mut Vec) -> Bytes { + let continuations = mem::take(continuations); + let total_len = continuations.iter().map(|b| b.len()).sum(); + + let mut buf = BytesMut::with_capacity(total_len); + + for chunk in continuations { + buf.extend(chunk); + } + + buf.freeze() +} diff --git a/actix-ws/src/fut.rs b/actix-ws/src/fut.rs index d62b9b232..3dd178a43 100644 --- a/actix-ws/src/fut.rs +++ b/actix-ws/src/fut.rs @@ -15,9 +15,12 @@ use actix_web::{ web::{Bytes, BytesMut}, Error, }; +use bytestring::ByteString; use futures_core::stream::Stream; use tokio::sync::mpsc::Receiver; +use crate::AggregatedMessageStream; + /// A response body for Websocket HTTP Requests pub struct StreamingBody { session_rx: Receiver, @@ -28,18 +31,6 @@ pub struct StreamingBody { closing: bool, } -/// A stream of Messages from a websocket client -/// -/// Messages can be accessed via the stream's `.next()` method -pub struct MessageStream { - payload: Payload, - - messages: VecDeque, - buf: BytesMut, - codec: Codec, - closing: bool, -} - impl StreamingBody { pub(super) fn new(session_rx: Receiver) -> Self { StreamingBody { @@ -52,6 +43,16 @@ impl StreamingBody { } } +/// Stream of Messages from a websocket client. +pub struct MessageStream { + payload: Payload, + + messages: VecDeque, + buf: BytesMut, + codec: Codec, + closing: bool, +} + impl MessageStream { pub(super) fn new(payload: Payload) -> Self { MessageStream { @@ -63,7 +64,40 @@ impl MessageStream { } } - /// Wait for the next item from the message stream + /// Sets the maximum permitted size for received WebSocket frames, in bytes. + /// + /// By default, up to 64KiB is allowed. + /// + /// Any received frames larger than the permitted value will return + /// `Err(ProtocolError::Overflow)` instead. + /// + /// ```no_run + /// # use actix_ws::MessageStream; + /// # fn test(stream: MessageStream) { + /// // increase permitted frame size from 64KB to 1MB + /// let stream = stream.max_frame_size(1024 * 1024); + /// # } + /// ``` + #[must_use] + pub fn max_frame_size(mut self, max_size: usize) -> Self { + self.codec = self.codec.max_size(max_size); + self + } + + /// Returns a stream wrapper that collects continuation frames into their equivalent aggregated + /// forms, i.e., binary or text. + /// + /// By default, continuations will be aggregated up to 1MiB in size (customizable with + /// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an + /// error if this size is exceeded. + #[must_use] + pub fn aggregate_continuations(self) -> AggregatedMessageStream { + AggregatedMessageStream::new(self) + } + + /// Waits for the next item from the message stream + /// + /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation. /// /// ```no_run /// # use actix_ws::MessageStream; @@ -73,6 +107,7 @@ impl MessageStream { /// } /// # } /// ``` + #[must_use] pub async fn recv(&mut self) -> Option> { poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await } @@ -135,11 +170,8 @@ impl Stream for MessageStream { Poll::Ready(Some(Ok(bytes))) => { this.buf.extend_from_slice(&bytes); } - Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new( - io::ErrorKind::Other, - e.to_string(), - ))))); + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err))))); } Poll::Ready(None) => { this.closing = true; @@ -154,12 +186,11 @@ impl Stream for MessageStream { while let Some(frame) = this.codec.decode(&mut this.buf)? { let message = match frame { Frame::Text(bytes) => { - let s = std::str::from_utf8(&bytes) - .map_err(|e| { - ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string())) + ByteString::try_from(bytes) + .map(Message::Text) + .map_err(|err| { + ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err)) })? - .to_string(); - Message::Text(s.into()) } Frame::Binary(bytes) => Message::Binary(bytes), Frame::Ping(bytes) => Message::Ping(bytes), diff --git a/actix-ws/src/lib.rs b/actix-ws/src/lib.rs index 2e78e7625..8a72c6d4e 100644 --- a/actix-ws/src/lib.rs +++ b/actix-ws/src/lib.rs @@ -15,10 +15,12 @@ use actix_http::{ use actix_web::{web, HttpRequest, HttpResponse}; use tokio::sync::mpsc::channel; +mod aggregated; mod fut; mod session; pub use self::{ + aggregated::{AggregatedMessage, AggregatedMessageStream}, fut::{MessageStream, StreamingBody}, session::{Closed, Session}, };