From 5c9e6e7c1d43d29c5b51638e198ae238ef5c51f7 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sat, 6 Jul 2024 22:58:54 +0100 Subject: [PATCH] feat(multipart): add field bytes method --- actix-multipart/CHANGES.md | 1 + actix-multipart/src/field.rs | 155 ++++++++++++++++++++++++++++++++++- 2 files changed, 152 insertions(+), 4 deletions(-) diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index dea24ab9d..86cd26060 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -4,6 +4,7 @@ - Add `MultipartError::ContentTypeIncompatible` variant. - Add `MultipartError::ContentDispositionNameMissing` variant. +- Add `Field::bytes()` method. - Rename `MultipartError::{NoContentDisposition => ContentDispositionMissing}` variant. - Rename `MultipartError::{NoContentType => ContentTypeMissing}` variant. - Rename `MultipartError::{ParseContentType => ContentTypeParse}` variant. diff --git a/actix-multipart/src/field.rs b/actix-multipart/src/field.rs index 50660b5d3..6bb1e5265 100644 --- a/actix-multipart/src/field.rs +++ b/actix-multipart/src/field.rs @@ -1,17 +1,19 @@ use std::{ cell::RefCell, - cmp, fmt, + cmp, fmt, mem, pin::Pin, rc::Rc, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; +use actix_utils::future::poll_fn; use actix_web::{ error::PayloadError, http::header::{self, ContentDisposition, HeaderMap}, - web::Bytes, + web::{Bytes, BytesMut}, }; -use futures_core::stream::Stream; +use derive_more::{Display, Error}; +use futures_core::Stream; use mime::Mime; use crate::{ @@ -20,6 +22,10 @@ use crate::{ safety::Safety, }; +#[derive(Debug, Display, Error)] +#[display(fmt = "limit exceeded")] +pub struct LimitExceeded; + /// A single field in a multipart stream. pub struct Field { /// Field's Content-Type. @@ -103,6 +109,56 @@ impl Field { pub fn name(&self) -> Option<&str> { self.content_disposition()?.get_name() } + + /// Collects the raw field data, up to `limit` bytes. + /// + /// # Errors + /// + /// Any errors produced by the data stream are returned as `Ok(Err(Error))` immediately. + /// + /// If the buffered data size would exceed `limit`, an `Err(LimitExceeded)` is returned. Note + /// that, in this case, the full data stream is exhausted before returning the error so that + /// subsequent fields can still be read. To better defend against malicious/infinite requests, + /// it is advisable to also put a timeout on this call. + pub async fn bytes(&mut self, limit: usize) -> Result, LimitExceeded> { + /// Sensible default (2kB) for initial, bounded allocation when collecting body bytes. + const INITIAL_ALLOC_BYTES: usize = 2 * 1024; + + let mut exceeded_limit = false; + let mut buf = BytesMut::with_capacity(INITIAL_ALLOC_BYTES); + + let mut field = Pin::new(self); + + match poll_fn(|cx| loop { + match ready!(field.as_mut().poll_next(cx)) { + // if already over limit, discard chunk to advance multipart request + Some(Ok(_chunk)) if exceeded_limit => {} + + // if limit is exceeded set flag to true and continue + Some(Ok(chunk)) if buf.len() + chunk.len() > limit => { + exceeded_limit = true; + // eagerly de-allocate field data buffer + let _ = mem::take(&mut buf); + } + + Some(Ok(chunk)) => buf.extend_from_slice(&chunk), + + None => return Poll::Ready(Ok(())), + Some(Err(err)) => return Poll::Ready(Err(err)), + } + }) + .await + { + // propagate error returned from body poll + Err(err) => Ok(Err(err)), + + // limit was exceeded while reading body + Ok(()) if exceeded_limit => Err(LimitExceeded), + + // otherwise return body buffer + Ok(()) => Ok(Ok(buf.freeze())), + } + } } impl Stream for Field { @@ -341,3 +397,94 @@ impl InnerField { result } } + +#[cfg(test)] +mod tests { + use futures_util::{stream, StreamExt as _}; + + use super::*; + use crate::Multipart; + + // TODO: use test utility when multi-file support is introduced + fn create_double_request_with_header() -> (Bytes, HeaderMap) { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\n\ + \r\n\ + one+one+one\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\n\ + \r\n\ + two+two+two\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + (bytes, headers) + } + + #[actix_rt::test] + async fn bytes_unlimited() { + let (body, headers) = create_double_request_with_header(); + + let mut multipart = Multipart::new(&headers, stream::iter([Ok(body)])); + + let field = multipart + .next() + .await + .expect("multipart should have two fields") + .expect("multipart body should be well formatted") + .bytes(usize::MAX) + .await + .expect("field data should not be size limited") + .expect("reading field data should not error"); + assert_eq!(field, "one+one+one"); + + let field = multipart + .next() + .await + .expect("multipart should have two fields") + .expect("multipart body should be well formatted") + .bytes(usize::MAX) + .await + .expect("field data should not be size limited") + .expect("reading field data should not error"); + assert_eq!(field, "two+two+two"); + } + + #[actix_rt::test] + async fn bytes_limited() { + let (body, headers) = create_double_request_with_header(); + + let mut multipart = Multipart::new(&headers, stream::iter([Ok(body)])); + + multipart + .next() + .await + .expect("multipart should have two fields") + .expect("multipart body should be well formatted") + .bytes(8) // smaller than data size + .await + .expect_err("field data should be size limited"); + + // next field still readable + let field = multipart + .next() + .await + .expect("multipart should have two fields") + .expect("multipart body should be well formatted") + .bytes(usize::MAX) + .await + .expect("field data should not be size limited") + .expect("reading field data should not error"); + assert_eq!(field, "two+two+two"); + } +}