mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-23 15:51:06 +01:00
Add better support for receiving larger payloads (#430)
* Add better support for receiving larger payloads This change enables the maximum frame size to be configured when receiving websocket frames. It also adds a new stream time that aggregates continuation frames together into their proper collected representation. It provides no mechanism yet for sending continuations. * actix-ws: Add continuation & size config to changelog * actix-ws: Add Debug, Eq to AggregatedMessage * actix-ws: Add a configurable maximum size to aggregated continuations * refactor: move aggregate types to own module * test: fix chat example * docs: update changelog --------- Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
parent
0802eff40d
commit
b0d2947a4a
@ -2,10 +2,12 @@
|
||||
|
||||
## Unreleased
|
||||
|
||||
- Take the encoded buffer when yielding bytes in the response stream rather than splitting the buffer, reducing memory use
|
||||
- Add `AggregatedMessage[Stream]` types.
|
||||
- Add `MessageStream::max_frame_size()` setter method.
|
||||
- Add `Session::continuation()` method.
|
||||
- The `Session::text()` method now receives an `impl Into<ByteString>`, making broadcasting text messages more efficient.
|
||||
- Remove type parameters from `Session::{text, binary}()` methods, replacing with equivalent `impl Trait` parameters.
|
||||
- `Session::text()` now receives an `impl Into<ByteString>`, making broadcasting text messages more efficient.
|
||||
- Allow sending continuations via `Session::continuation()`
|
||||
- Reduce memory usage by `take`-ing (rather than `split`-ing) the encoded buffer when yielding bytes in the response stream.
|
||||
|
||||
## 0.2.5
|
||||
|
||||
|
@ -7,7 +7,7 @@ use std::{
|
||||
use actix_web::{
|
||||
middleware::Logger, web, web::Html, App, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
};
|
||||
use actix_ws::{Message, Session};
|
||||
use actix_ws::{AggregatedMessage, Session};
|
||||
use bytestring::ByteString;
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt as _};
|
||||
use tokio::sync::Mutex;
|
||||
@ -65,7 +65,10 @@ async fn ws(
|
||||
body: web::Payload,
|
||||
chat: web::Data<Chat>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (response, mut session, mut stream) = actix_ws::handle(&req, body)?;
|
||||
let (response, mut session, stream) = actix_ws::handle(&req, body)?;
|
||||
|
||||
// increase the maximum allowed frame size to 128KiB and aggregate continuation frames
|
||||
let mut stream = stream.max_frame_size(128 * 1024).aggregate_continuations();
|
||||
|
||||
chat.insert(session.clone()).await;
|
||||
tracing::info!("Inserted session");
|
||||
@ -91,30 +94,29 @@ async fn ws(
|
||||
});
|
||||
|
||||
actix_web::rt::spawn(async move {
|
||||
while let Some(Ok(msg)) = stream.next().await {
|
||||
while let Some(Ok(msg)) = stream.recv().await {
|
||||
match msg {
|
||||
Message::Ping(bytes) => {
|
||||
AggregatedMessage::Ping(bytes) => {
|
||||
if session.pong(&bytes).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Message::Text(msg) => {
|
||||
tracing::info!("Relaying msg: {msg}");
|
||||
chat.send(msg).await;
|
||||
|
||||
AggregatedMessage::Text(string) => {
|
||||
tracing::info!("Relaying text, {string}");
|
||||
chat.send(string).await;
|
||||
}
|
||||
Message::Close(reason) => {
|
||||
|
||||
AggregatedMessage::Close(reason) => {
|
||||
let _ = session.close(reason).await;
|
||||
tracing::info!("Got close, bailing");
|
||||
return;
|
||||
}
|
||||
Message::Continuation(_) => {
|
||||
let _ = session.close(None).await;
|
||||
tracing::info!("Got continuation, bailing");
|
||||
return;
|
||||
}
|
||||
Message::Pong(_) => {
|
||||
|
||||
AggregatedMessage::Pong(_) => {
|
||||
*alive.lock().await = Instant::now();
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
|
216
actix-ws/src/aggregated.rs
Normal file
216
actix-ws/src/aggregated.rs
Normal file
@ -0,0 +1,216 @@
|
||||
//! WebSocket stream for aggregating continuation frames.
|
||||
|
||||
use std::{
|
||||
future::poll_fn,
|
||||
io, mem,
|
||||
pin::Pin,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
|
||||
use actix_web::web::{Bytes, BytesMut};
|
||||
use bytestring::ByteString;
|
||||
use futures_core::Stream;
|
||||
|
||||
use crate::MessageStream;
|
||||
|
||||
pub(crate) enum ContinuationKind {
|
||||
Text,
|
||||
Binary,
|
||||
}
|
||||
|
||||
/// WebSocket message with any continuations aggregated together.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum AggregatedMessage {
|
||||
/// Text message.
|
||||
Text(ByteString),
|
||||
|
||||
/// Binary message.
|
||||
Binary(Bytes),
|
||||
|
||||
/// Ping message.
|
||||
Ping(Bytes),
|
||||
|
||||
/// Pong message.
|
||||
Pong(Bytes),
|
||||
|
||||
/// Close message with optional reason.
|
||||
Close(Option<CloseReason>),
|
||||
}
|
||||
|
||||
/// Stream of messages from a WebSocket client, with continuations aggregated.
|
||||
pub struct AggregatedMessageStream {
|
||||
stream: MessageStream,
|
||||
current_size: usize,
|
||||
max_size: usize,
|
||||
continuations: Vec<Bytes>,
|
||||
continuation_kind: ContinuationKind,
|
||||
}
|
||||
|
||||
impl AggregatedMessageStream {
|
||||
#[must_use]
|
||||
pub(crate) fn new(stream: MessageStream) -> Self {
|
||||
AggregatedMessageStream {
|
||||
stream,
|
||||
current_size: 0,
|
||||
max_size: 1024 * 1024,
|
||||
continuations: Vec::new(),
|
||||
continuation_kind: ContinuationKind::Binary,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the maximum allowed size for aggregated continuations, in bytes.
|
||||
///
|
||||
/// By default, up to 1 MiB is allowed.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use actix_ws::AggregatedMessageStream;
|
||||
/// # async fn test(stream: AggregatedMessageStream) {
|
||||
/// // increase the allowed size from 1MB to 8MB
|
||||
/// let mut stream = stream.max_continuation_size(8 * 1024 * 1024);
|
||||
///
|
||||
/// while let Some(Ok(msg)) = stream.recv().await {
|
||||
/// // handle message
|
||||
/// }
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn max_continuation_size(mut self, max_size: usize) -> Self {
|
||||
self.max_size = max_size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Waits for the next item from the aggregated message stream.
|
||||
///
|
||||
/// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use actix_ws::AggregatedMessageStream;
|
||||
/// # async fn test(mut stream: AggregatedMessageStream) {
|
||||
/// while let Some(Ok(msg)) = stream.recv().await {
|
||||
/// // handle message
|
||||
/// }
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
|
||||
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
|
||||
}
|
||||
}
|
||||
|
||||
fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
|
||||
Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
|
||||
"Exceeded maximum continuation size",
|
||||
)))))
|
||||
}
|
||||
|
||||
impl Stream for AggregatedMessageStream {
|
||||
type Item = Result<AggregatedMessage, ProtocolError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
|
||||
return Poll::Ready(None);
|
||||
};
|
||||
|
||||
match msg {
|
||||
Message::Continuation(item) => match item {
|
||||
Item::FirstText(bytes) => {
|
||||
this.continuation_kind = ContinuationKind::Text;
|
||||
this.current_size += bytes.len();
|
||||
|
||||
if this.current_size > this.max_size {
|
||||
this.continuations.clear();
|
||||
return size_error();
|
||||
}
|
||||
|
||||
this.continuations.push(bytes);
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
Item::FirstBinary(bytes) => {
|
||||
this.continuation_kind = ContinuationKind::Binary;
|
||||
this.current_size += bytes.len();
|
||||
|
||||
if this.current_size > this.max_size {
|
||||
this.continuations.clear();
|
||||
return size_error();
|
||||
}
|
||||
|
||||
this.continuations.push(bytes);
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
Item::Continue(bytes) => {
|
||||
this.current_size += bytes.len();
|
||||
|
||||
if this.current_size > this.max_size {
|
||||
this.continuations.clear();
|
||||
return size_error();
|
||||
}
|
||||
|
||||
this.continuations.push(bytes);
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
Item::Last(bytes) => {
|
||||
this.current_size += bytes.len();
|
||||
|
||||
if this.current_size > this.max_size {
|
||||
// reset current_size, as this is the last message for
|
||||
// the current continuation
|
||||
this.current_size = 0;
|
||||
this.continuations.clear();
|
||||
|
||||
return size_error();
|
||||
}
|
||||
|
||||
this.continuations.push(bytes);
|
||||
let bytes = collect(&mut this.continuations);
|
||||
|
||||
this.current_size = 0;
|
||||
|
||||
match this.continuation_kind {
|
||||
ContinuationKind::Text => {
|
||||
Poll::Ready(Some(match ByteString::try_from(bytes) {
|
||||
Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
|
||||
Err(err) => Err(ProtocolError::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
err.to_string(),
|
||||
))),
|
||||
}))
|
||||
}
|
||||
ContinuationKind::Binary => {
|
||||
Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Message::Text(text) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
|
||||
Message::Binary(binary) => Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary)))),
|
||||
Message::Ping(ping) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
|
||||
Message::Pong(pong) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
|
||||
Message::Close(close) => Poll::Ready(Some(Ok(AggregatedMessage::Close(close)))),
|
||||
|
||||
Message::Nop => unreachable!("MessageStream should not produce no-ops"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
|
||||
let continuations = mem::take(continuations);
|
||||
let total_len = continuations.iter().map(|b| b.len()).sum();
|
||||
|
||||
let mut buf = BytesMut::with_capacity(total_len);
|
||||
|
||||
for chunk in continuations {
|
||||
buf.extend(chunk);
|
||||
}
|
||||
|
||||
buf.freeze()
|
||||
}
|
@ -15,9 +15,12 @@ use actix_web::{
|
||||
web::{Bytes, BytesMut},
|
||||
Error,
|
||||
};
|
||||
use bytestring::ByteString;
|
||||
use futures_core::stream::Stream;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
|
||||
use crate::AggregatedMessageStream;
|
||||
|
||||
/// A response body for Websocket HTTP Requests
|
||||
pub struct StreamingBody {
|
||||
session_rx: Receiver<Message>,
|
||||
@ -28,18 +31,6 @@ pub struct StreamingBody {
|
||||
closing: bool,
|
||||
}
|
||||
|
||||
/// A stream of Messages from a websocket client
|
||||
///
|
||||
/// Messages can be accessed via the stream's `.next()` method
|
||||
pub struct MessageStream {
|
||||
payload: Payload,
|
||||
|
||||
messages: VecDeque<Message>,
|
||||
buf: BytesMut,
|
||||
codec: Codec,
|
||||
closing: bool,
|
||||
}
|
||||
|
||||
impl StreamingBody {
|
||||
pub(super) fn new(session_rx: Receiver<Message>) -> Self {
|
||||
StreamingBody {
|
||||
@ -52,6 +43,16 @@ impl StreamingBody {
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream of Messages from a websocket client.
|
||||
pub struct MessageStream {
|
||||
payload: Payload,
|
||||
|
||||
messages: VecDeque<Message>,
|
||||
buf: BytesMut,
|
||||
codec: Codec,
|
||||
closing: bool,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
pub(super) fn new(payload: Payload) -> Self {
|
||||
MessageStream {
|
||||
@ -63,7 +64,40 @@ impl MessageStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for the next item from the message stream
|
||||
/// Sets the maximum permitted size for received WebSocket frames, in bytes.
|
||||
///
|
||||
/// By default, up to 64KiB is allowed.
|
||||
///
|
||||
/// Any received frames larger than the permitted value will return
|
||||
/// `Err(ProtocolError::Overflow)` instead.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use actix_ws::MessageStream;
|
||||
/// # fn test(stream: MessageStream) {
|
||||
/// // increase permitted frame size from 64KB to 1MB
|
||||
/// let stream = stream.max_frame_size(1024 * 1024);
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn max_frame_size(mut self, max_size: usize) -> Self {
|
||||
self.codec = self.codec.max_size(max_size);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns a stream wrapper that collects continuation frames into their equivalent aggregated
|
||||
/// forms, i.e., binary or text.
|
||||
///
|
||||
/// By default, continuations will be aggregated up to 1MiB in size (customizable with
|
||||
/// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an
|
||||
/// error if this size is exceeded.
|
||||
#[must_use]
|
||||
pub fn aggregate_continuations(self) -> AggregatedMessageStream {
|
||||
AggregatedMessageStream::new(self)
|
||||
}
|
||||
|
||||
/// Waits for the next item from the message stream
|
||||
///
|
||||
/// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use actix_ws::MessageStream;
|
||||
@ -73,6 +107,7 @@ impl MessageStream {
|
||||
/// }
|
||||
/// # }
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
|
||||
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
|
||||
}
|
||||
@ -135,11 +170,8 @@ impl Stream for MessageStream {
|
||||
Poll::Ready(Some(Ok(bytes))) => {
|
||||
this.buf.extend_from_slice(&bytes);
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)))));
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
this.closing = true;
|
||||
@ -154,12 +186,11 @@ impl Stream for MessageStream {
|
||||
while let Some(frame) = this.codec.decode(&mut this.buf)? {
|
||||
let message = match frame {
|
||||
Frame::Text(bytes) => {
|
||||
let s = std::str::from_utf8(&bytes)
|
||||
.map_err(|e| {
|
||||
ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string()))
|
||||
ByteString::try_from(bytes)
|
||||
.map(Message::Text)
|
||||
.map_err(|err| {
|
||||
ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
})?
|
||||
.to_string();
|
||||
Message::Text(s.into())
|
||||
}
|
||||
Frame::Binary(bytes) => Message::Binary(bytes),
|
||||
Frame::Ping(bytes) => Message::Ping(bytes),
|
||||
|
@ -15,10 +15,12 @@ use actix_http::{
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
use tokio::sync::mpsc::channel;
|
||||
|
||||
mod aggregated;
|
||||
mod fut;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
aggregated::{AggregatedMessage, AggregatedMessageStream},
|
||||
fut::{MessageStream, StreamingBody},
|
||||
session::{Closed, Session},
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user