1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-22 06:45:56 +01:00

Add better support for receiving larger payloads (#430)

* Add better support for receiving larger payloads

This change enables the maximum frame size to be configured when receiving websocket frames. It also
adds a new stream time that aggregates continuation frames together into their proper collected
representation. It provides no mechanism yet for sending continuations.

* actix-ws: Add continuation & size config to changelog

* actix-ws: Add Debug, Eq to AggregatedMessage

* actix-ws: Add a configurable maximum size to aggregated continuations

* refactor: move aggregate types to own module

* test: fix chat example

* docs: update changelog

---------

Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
asonix 2024-07-20 01:09:30 -05:00 committed by GitHub
parent 0802eff40d
commit b0d2947a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 293 additions and 40 deletions

View File

@ -2,10 +2,12 @@
## Unreleased
- Take the encoded buffer when yielding bytes in the response stream rather than splitting the buffer, reducing memory use
- Add `AggregatedMessage[Stream]` types.
- Add `MessageStream::max_frame_size()` setter method.
- Add `Session::continuation()` method.
- The `Session::text()` method now receives an `impl Into<ByteString>`, making broadcasting text messages more efficient.
- Remove type parameters from `Session::{text, binary}()` methods, replacing with equivalent `impl Trait` parameters.
- `Session::text()` now receives an `impl Into<ByteString>`, making broadcasting text messages more efficient.
- Allow sending continuations via `Session::continuation()`
- Reduce memory usage by `take`-ing (rather than `split`-ing) the encoded buffer when yielding bytes in the response stream.
## 0.2.5

View File

@ -7,7 +7,7 @@ use std::{
use actix_web::{
middleware::Logger, web, web::Html, App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use actix_ws::{Message, Session};
use actix_ws::{AggregatedMessage, Session};
use bytestring::ByteString;
use futures_util::{stream::FuturesUnordered, StreamExt as _};
use tokio::sync::Mutex;
@ -65,7 +65,10 @@ async fn ws(
body: web::Payload,
chat: web::Data<Chat>,
) -> Result<HttpResponse, actix_web::Error> {
let (response, mut session, mut stream) = actix_ws::handle(&req, body)?;
let (response, mut session, stream) = actix_ws::handle(&req, body)?;
// increase the maximum allowed frame size to 128KiB and aggregate continuation frames
let mut stream = stream.max_frame_size(128 * 1024).aggregate_continuations();
chat.insert(session.clone()).await;
tracing::info!("Inserted session");
@ -91,30 +94,29 @@ async fn ws(
});
actix_web::rt::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
while let Some(Ok(msg)) = stream.recv().await {
match msg {
Message::Ping(bytes) => {
AggregatedMessage::Ping(bytes) => {
if session.pong(&bytes).await.is_err() {
return;
}
}
Message::Text(msg) => {
tracing::info!("Relaying msg: {msg}");
chat.send(msg).await;
AggregatedMessage::Text(string) => {
tracing::info!("Relaying text, {string}");
chat.send(string).await;
}
Message::Close(reason) => {
AggregatedMessage::Close(reason) => {
let _ = session.close(reason).await;
tracing::info!("Got close, bailing");
return;
}
Message::Continuation(_) => {
let _ = session.close(None).await;
tracing::info!("Got continuation, bailing");
return;
}
Message::Pong(_) => {
AggregatedMessage::Pong(_) => {
*alive.lock().await = Instant::now();
}
_ => (),
};
}

216
actix-ws/src/aggregated.rs Normal file
View File

@ -0,0 +1,216 @@
//! WebSocket stream for aggregating continuation frames.
use std::{
future::poll_fn,
io, mem,
pin::Pin,
task::{ready, Context, Poll},
};
use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
use actix_web::web::{Bytes, BytesMut};
use bytestring::ByteString;
use futures_core::Stream;
use crate::MessageStream;
pub(crate) enum ContinuationKind {
Text,
Binary,
}
/// WebSocket message with any continuations aggregated together.
#[derive(Debug, PartialEq, Eq)]
pub enum AggregatedMessage {
/// Text message.
Text(ByteString),
/// Binary message.
Binary(Bytes),
/// Ping message.
Ping(Bytes),
/// Pong message.
Pong(Bytes),
/// Close message with optional reason.
Close(Option<CloseReason>),
}
/// Stream of messages from a WebSocket client, with continuations aggregated.
pub struct AggregatedMessageStream {
stream: MessageStream,
current_size: usize,
max_size: usize,
continuations: Vec<Bytes>,
continuation_kind: ContinuationKind,
}
impl AggregatedMessageStream {
#[must_use]
pub(crate) fn new(stream: MessageStream) -> Self {
AggregatedMessageStream {
stream,
current_size: 0,
max_size: 1024 * 1024,
continuations: Vec::new(),
continuation_kind: ContinuationKind::Binary,
}
}
/// Sets the maximum allowed size for aggregated continuations, in bytes.
///
/// By default, up to 1 MiB is allowed.
///
/// ```no_run
/// # use actix_ws::AggregatedMessageStream;
/// # async fn test(stream: AggregatedMessageStream) {
/// // increase the allowed size from 1MB to 8MB
/// let mut stream = stream.max_continuation_size(8 * 1024 * 1024);
///
/// while let Some(Ok(msg)) = stream.recv().await {
/// // handle message
/// }
/// # }
/// ```
#[must_use]
pub fn max_continuation_size(mut self, max_size: usize) -> Self {
self.max_size = max_size;
self
}
/// Waits for the next item from the aggregated message stream.
///
/// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
///
/// ```no_run
/// # use actix_ws::AggregatedMessageStream;
/// # async fn test(mut stream: AggregatedMessageStream) {
/// while let Some(Ok(msg)) = stream.recv().await {
/// // handle message
/// }
/// # }
/// ```
#[must_use]
pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
}
fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
"Exceeded maximum continuation size",
)))))
}
impl Stream for AggregatedMessageStream {
type Item = Result<AggregatedMessage, ProtocolError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
return Poll::Ready(None);
};
match msg {
Message::Continuation(item) => match item {
Item::FirstText(bytes) => {
this.continuation_kind = ContinuationKind::Text;
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.continuations.clear();
return size_error();
}
this.continuations.push(bytes);
Poll::Pending
}
Item::FirstBinary(bytes) => {
this.continuation_kind = ContinuationKind::Binary;
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.continuations.clear();
return size_error();
}
this.continuations.push(bytes);
Poll::Pending
}
Item::Continue(bytes) => {
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.continuations.clear();
return size_error();
}
this.continuations.push(bytes);
Poll::Pending
}
Item::Last(bytes) => {
this.current_size += bytes.len();
if this.current_size > this.max_size {
// reset current_size, as this is the last message for
// the current continuation
this.current_size = 0;
this.continuations.clear();
return size_error();
}
this.continuations.push(bytes);
let bytes = collect(&mut this.continuations);
this.current_size = 0;
match this.continuation_kind {
ContinuationKind::Text => {
Poll::Ready(Some(match ByteString::try_from(bytes) {
Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
Err(err) => Err(ProtocolError::Io(io::Error::new(
io::ErrorKind::InvalidData,
err.to_string(),
))),
}))
}
ContinuationKind::Binary => {
Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
}
}
}
},
Message::Text(text) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
Message::Binary(binary) => Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary)))),
Message::Ping(ping) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
Message::Pong(pong) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
Message::Close(close) => Poll::Ready(Some(Ok(AggregatedMessage::Close(close)))),
Message::Nop => unreachable!("MessageStream should not produce no-ops"),
}
}
}
fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
let continuations = mem::take(continuations);
let total_len = continuations.iter().map(|b| b.len()).sum();
let mut buf = BytesMut::with_capacity(total_len);
for chunk in continuations {
buf.extend(chunk);
}
buf.freeze()
}

View File

@ -15,9 +15,12 @@ use actix_web::{
web::{Bytes, BytesMut},
Error,
};
use bytestring::ByteString;
use futures_core::stream::Stream;
use tokio::sync::mpsc::Receiver;
use crate::AggregatedMessageStream;
/// A response body for Websocket HTTP Requests
pub struct StreamingBody {
session_rx: Receiver<Message>,
@ -28,18 +31,6 @@ pub struct StreamingBody {
closing: bool,
}
/// A stream of Messages from a websocket client
///
/// Messages can be accessed via the stream's `.next()` method
pub struct MessageStream {
payload: Payload,
messages: VecDeque<Message>,
buf: BytesMut,
codec: Codec,
closing: bool,
}
impl StreamingBody {
pub(super) fn new(session_rx: Receiver<Message>) -> Self {
StreamingBody {
@ -52,6 +43,16 @@ impl StreamingBody {
}
}
/// Stream of Messages from a websocket client.
pub struct MessageStream {
payload: Payload,
messages: VecDeque<Message>,
buf: BytesMut,
codec: Codec,
closing: bool,
}
impl MessageStream {
pub(super) fn new(payload: Payload) -> Self {
MessageStream {
@ -63,7 +64,40 @@ impl MessageStream {
}
}
/// Wait for the next item from the message stream
/// Sets the maximum permitted size for received WebSocket frames, in bytes.
///
/// By default, up to 64KiB is allowed.
///
/// Any received frames larger than the permitted value will return
/// `Err(ProtocolError::Overflow)` instead.
///
/// ```no_run
/// # use actix_ws::MessageStream;
/// # fn test(stream: MessageStream) {
/// // increase permitted frame size from 64KB to 1MB
/// let stream = stream.max_frame_size(1024 * 1024);
/// # }
/// ```
#[must_use]
pub fn max_frame_size(mut self, max_size: usize) -> Self {
self.codec = self.codec.max_size(max_size);
self
}
/// Returns a stream wrapper that collects continuation frames into their equivalent aggregated
/// forms, i.e., binary or text.
///
/// By default, continuations will be aggregated up to 1MiB in size (customizable with
/// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an
/// error if this size is exceeded.
#[must_use]
pub fn aggregate_continuations(self) -> AggregatedMessageStream {
AggregatedMessageStream::new(self)
}
/// Waits for the next item from the message stream
///
/// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
///
/// ```no_run
/// # use actix_ws::MessageStream;
@ -73,6 +107,7 @@ impl MessageStream {
/// }
/// # }
/// ```
#[must_use]
pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
@ -135,11 +170,8 @@ impl Stream for MessageStream {
Poll::Ready(Some(Ok(bytes))) => {
this.buf.extend_from_slice(&bytes);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
io::ErrorKind::Other,
e.to_string(),
)))));
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
}
Poll::Ready(None) => {
this.closing = true;
@ -154,12 +186,11 @@ impl Stream for MessageStream {
while let Some(frame) = this.codec.decode(&mut this.buf)? {
let message = match frame {
Frame::Text(bytes) => {
let s = std::str::from_utf8(&bytes)
.map_err(|e| {
ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string()))
ByteString::try_from(bytes)
.map(Message::Text)
.map_err(|err| {
ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
})?
.to_string();
Message::Text(s.into())
}
Frame::Binary(bytes) => Message::Binary(bytes),
Frame::Ping(bytes) => Message::Ping(bytes),

View File

@ -15,10 +15,12 @@ use actix_http::{
use actix_web::{web, HttpRequest, HttpResponse};
use tokio::sync::mpsc::channel;
mod aggregated;
mod fut;
mod session;
pub use self::{
aggregated::{AggregatedMessage, AggregatedMessageStream},
fut::{MessageStream, StreamingBody},
session::{Closed, Session},
};