1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-18 05:41:50 +01:00

refactor: multipart tweaks

This commit is contained in:
Rob Ede 2024-07-04 04:53:10 +01:00
parent 00c185f617
commit 210c9a5eb3
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
7 changed files with 169 additions and 135 deletions

View File

@ -10,7 +10,7 @@ use derive_more::{Display, Error, From};
/// A set of errors that can occur during parsing multipart streams.
#[derive(Debug, Display, From, Error)]
#[non_exhaustive]
pub enum MultipartError {
pub enum Error {
/// Could not find Content-Type header.
#[display(fmt = "Could not find Content-Type header")]
ContentTypeMissing,
@ -95,11 +95,11 @@ pub enum MultipartError {
}
/// Return `BadRequest` for `MultipartError`.
impl ResponseError for MultipartError {
impl ResponseError for Error {
fn status_code(&self) -> StatusCode {
match &self {
MultipartError::Field { source, .. } => source.as_response_error().status_code(),
MultipartError::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE,
Error::Field { source, .. } => source.as_response_error().status_code(),
Error::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE,
_ => StatusCode::BAD_REQUEST,
}
}
@ -111,7 +111,7 @@ mod tests {
#[test]
fn test_multipart_error() {
let resp = MultipartError::BoundaryMissing.error_response();
let resp = Error::BoundaryMissing.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
}

View File

@ -12,11 +12,11 @@ use crate::server::Multipart;
/// # Examples
///
/// ```
/// use actix_web::{web, HttpResponse, Error};
/// use actix_web::{web, HttpResponse};
/// use actix_multipart::Multipart;
/// use futures_util::StreamExt as _;
///
/// async fn index(mut payload: Multipart) -> Result<HttpResponse, Error> {
/// async fn index(mut payload: Multipart) -> actix_web::Result<HttpResponse> {
/// // iterate over multipart stream
/// while let Some(item) = payload.next().await {
/// let mut field = item?;
@ -27,7 +27,7 @@ use crate::server::Multipart;
/// }
/// }
///
/// Ok(HttpResponse::Ok().into())
/// Ok(HttpResponse::Ok().finish())
/// }
/// ```
impl FromRequest for Multipart {

View File

@ -15,7 +15,7 @@ use futures_core::stream::Stream;
use mime::Mime;
use crate::{
error::MultipartError,
error::Error,
payload::{PayloadBuffer, PayloadRef},
safety::Safety,
};
@ -106,7 +106,7 @@ impl Field {
}
impl Stream for Field {
type Item = Result<Bytes, MultipartError>;
type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
@ -122,7 +122,7 @@ impl Stream for Field {
buffer.poll_stream(cx)?;
} else if !this.safety.is_clean() {
// safety violation
return Poll::Ready(Some(Err(MultipartError::NotConsumed)));
return Poll::Ready(Some(Err(Error::NotConsumed)));
} else {
return Poll::Pending;
}
@ -192,7 +192,7 @@ impl InnerField {
pub(crate) fn read_len(
payload: &mut PayloadBuffer,
size: &mut u64,
) -> Poll<Option<Result<Bytes, MultipartError>>> {
) -> Poll<Option<Result<Bytes, Error>>> {
if *size == 0 {
Poll::Ready(None)
} else {
@ -208,7 +208,7 @@ impl InnerField {
}
None => {
if payload.eof && (*size != 0) {
Poll::Ready(Some(Err(MultipartError::Incomplete)))
Poll::Ready(Some(Err(Error::Incomplete)))
} else {
Poll::Pending
}
@ -223,13 +223,13 @@ impl InnerField {
pub(crate) fn read_stream(
payload: &mut PayloadBuffer,
boundary: &str,
) -> Poll<Option<Result<Bytes, MultipartError>>> {
) -> Poll<Option<Result<Bytes, Error>>> {
let mut pos = 0;
let len = payload.buf.len();
if len == 0 {
return if payload.eof {
Poll::Ready(Some(Err(MultipartError::Incomplete)))
Poll::Ready(Some(Err(Error::Incomplete)))
} else {
Poll::Pending
};
@ -293,7 +293,7 @@ impl InnerField {
}
}
pub(crate) fn poll(&mut self, safety: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> {
pub(crate) fn poll(&mut self, safety: &Safety) -> Poll<Option<Result<Bytes, Error>>> {
if self.payload.is_none() {
return Poll::Ready(None);
}

View File

@ -63,4 +63,4 @@ pub(crate) mod safety;
mod server;
pub mod test;
pub use self::{error::MultipartError, field::Field, server::Multipart};
pub use self::{error::Error as MultipartError, field::Field, server::Multipart};

View File

@ -1,6 +1,6 @@
use std::{
cell::{RefCell, RefMut},
cmp,
cmp, mem,
pin::Pin,
rc::Rc,
task::{Context, Poll},
@ -12,7 +12,7 @@ use actix_web::{
};
use futures_core::stream::{LocalBoxStream, Stream};
use crate::{error::MultipartError, safety::Safety};
use crate::{error::Error, safety::Safety};
pub(crate) struct PayloadRef {
payload: Rc<RefCell<PayloadBuffer>>,
@ -21,7 +21,7 @@ pub(crate) struct PayloadRef {
impl PayloadRef {
pub(crate) fn new(payload: PayloadBuffer) -> PayloadRef {
PayloadRef {
payload: Rc::new(payload.into()),
payload: Rc::new(RefCell::new(payload)),
}
}
@ -44,28 +44,33 @@ impl Clone for PayloadRef {
/// Payload buffer.
pub(crate) struct PayloadBuffer {
pub(crate) eof: bool,
pub(crate) buf: BytesMut,
pub(crate) stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
pub(crate) buf: BytesMut,
/// EOF flag. If true, no more payload reads will be attempted.
pub(crate) eof: bool,
}
impl PayloadBuffer {
/// Constructs new `PayloadBuffer` instance.
/// Constructs new payload buffer.
pub(crate) fn new<S>(stream: S) -> Self
where
S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
PayloadBuffer {
eof: false,
buf: BytesMut::new(),
stream: Box::pin(stream),
buf: BytesMut::with_capacity(1_024), // pre-allocate 1KiB
eof: false,
}
}
pub(crate) fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
loop {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
Poll::Ready(Some(Ok(data))) => {
self.buf.extend_from_slice(&data);
// try to read more data
continue;
}
Poll::Ready(Some(Err(err))) => return Err(err),
Poll::Ready(None) => {
self.eof = true;
@ -76,7 +81,7 @@ impl PayloadBuffer {
}
}
/// Read exact number of bytes.
/// Reads exact number of bytes.
#[cfg(test)]
pub(crate) fn read_exact(&mut self, size: usize) -> Option<Bytes> {
if size <= self.buf.len() {
@ -86,46 +91,57 @@ impl PayloadBuffer {
}
}
pub(crate) fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, MultipartError> {
pub(crate) fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, Error> {
if !self.buf.is_empty() {
let size = cmp::min(self.buf.len() as u64, size) as usize;
Ok(Some(self.buf.split_to(size).freeze()))
} else if self.eof {
Err(MultipartError::Incomplete)
Err(Error::Incomplete)
} else {
Ok(None)
}
}
/// Read until specified ending.
pub(crate) fn read_until(&mut self, line: &[u8]) -> Result<Option<Bytes>, MultipartError> {
let res = memchr::memmem::find(&self.buf, line)
.map(|idx| self.buf.split_to(idx + line.len()).freeze());
/// Reads until specified ending.
///
/// Returns:
///
/// - `Ok(Some(chunk))` - `needle` is found, with chunk ending after needle
/// - `Err(Incomplete)` - `needle` is not found and we're at EOF
/// - `Ok(None)` - `needle` is not found otherwise
pub(crate) fn read_until(&mut self, needle: &[u8]) -> Result<Option<Bytes>, Error> {
match memchr::memmem::find(&self.buf, needle) {
// buffer exhausted and EOF without finding needle
None if self.eof => Err(Error::Incomplete),
if res.is_none() && self.eof {
Err(MultipartError::Incomplete)
} else {
Ok(res)
// needle not yet found
None => Ok(None),
// needle found, split chunk out of buf
Some(idx) => Ok(Some(self.buf.split_to(idx + needle.len()).freeze())),
}
}
/// Read bytes until new line delimiter.
pub(crate) fn readline(&mut self) -> Result<Option<Bytes>, MultipartError> {
/// Reads bytes until new line delimiter.
#[inline]
pub(crate) fn readline(&mut self) -> Result<Option<Bytes>, Error> {
self.read_until(b"\n")
}
/// Read bytes until new line delimiter or EOF.
pub(crate) fn readline_or_eof(&mut self) -> Result<Option<Bytes>, MultipartError> {
/// Reads bytes until new line delimiter or until EOF.
#[inline]
pub(crate) fn readline_or_eof(&mut self) -> Result<Option<Bytes>, Error> {
match self.readline() {
Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())),
Err(Error::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())),
line => line,
}
}
/// Put unprocessed data back to the buffer.
/// Puts unprocessed data back to the buffer.
pub(crate) fn unprocessed(&mut self, data: Bytes) {
let buf = BytesMut::from(data.as_ref());
let buf = std::mem::replace(&mut self.buf, buf);
// TODO: use BytesMut::from when it's released, see https://github.com/tokio-rs/bytes/pull/710
let buf = BytesMut::from(&data[..]);
let buf = mem::replace(&mut self.buf, buf);
self.buf.extend_from_slice(&buf);
}
}

View File

@ -18,7 +18,7 @@ use futures_core::stream::Stream;
use mime::Mime;
use crate::{
error::MultipartError,
error::Error,
field::InnerField,
payload::{PayloadBuffer, PayloadRef},
safety::Safety,
@ -33,9 +33,15 @@ const MAX_HEADERS: usize = 32;
/// implementation. `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` is
/// used for nested multipart streams.
pub struct Multipart {
flow: Flow,
safety: Safety,
inner: Option<Inner>,
error: Option<MultipartError>,
}
enum Flow {
InFlight(Inner),
/// Error container is Some until an error is returned out of the flow.
Error(Option<Error>),
}
impl Multipart {
@ -59,24 +65,22 @@ impl Multipart {
}
/// Extract Content-Type and boundary info from headers.
pub(crate) fn find_ct_and_boundary(
headers: &HeaderMap,
) -> Result<(Mime, String), MultipartError> {
pub(crate) fn find_ct_and_boundary(headers: &HeaderMap) -> Result<(Mime, String), Error> {
let content_type = headers
.get(&header::CONTENT_TYPE)
.ok_or(MultipartError::ContentTypeMissing)?
.ok_or(Error::ContentTypeMissing)?
.to_str()
.ok()
.and_then(|content_type| content_type.parse::<Mime>().ok())
.ok_or(MultipartError::ContentTypeParse)?;
.ok_or(Error::ContentTypeParse)?;
if content_type.type_() != mime::MULTIPART {
return Err(MultipartError::ContentTypeIncompatible);
return Err(Error::ContentTypeIncompatible);
}
let boundary = content_type
.get_param(mime::BOUNDARY)
.ok_or(MultipartError::BoundaryMissing)?
.ok_or(Error::BoundaryMissing)?
.as_str()
.to_owned();
@ -90,64 +94,57 @@ impl Multipart {
{
Multipart {
safety: Safety::new(),
inner: Some(Inner {
flow: Flow::InFlight(Inner {
payload: PayloadRef::new(PayloadBuffer::new(stream)),
content_type: ct,
boundary,
state: State::FirstBoundary,
item: Item::None,
}),
error: None,
}
}
/// Constructs a new multipart reader from given `MultipartError`.
pub(crate) fn from_error(err: MultipartError) -> Multipart {
pub(crate) fn from_error(err: Error) -> Multipart {
Multipart {
error: Some(err),
flow: Flow::Error(Some(err)),
safety: Safety::new(),
inner: None,
}
}
/// Return requests parsed Content-Type or raise the stored error.
pub(crate) fn content_type_or_bail(&mut self) -> Result<mime::Mime, MultipartError> {
if let Some(err) = self.error.take() {
return Err(err);
pub(crate) fn content_type_or_bail(&mut self) -> Result<mime::Mime, Error> {
match self.flow {
Flow::InFlight(ref inner) => Ok(inner.content_type.clone()),
Flow::Error(ref mut err) => Err(err
.take()
.expect("error should not be taken after it was returned")),
}
Ok(self
.inner
.as_ref()
// TODO: look into using enum instead of two options
.expect("multipart requests should have state")
.content_type
.clone())
}
}
impl Stream for Multipart {
type Item = Result<Field, MultipartError>;
type Item = Result<Field, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.inner.as_mut() {
Some(inner) => {
match this.flow {
Flow::InFlight(ref mut inner) => {
if let Some(mut buffer) = inner.payload.get_mut(&this.safety) {
// check safety and poll read payload to buffer.
buffer.poll_stream(cx)?;
} else if !this.safety.is_clean() {
// safety violation
return Poll::Ready(Some(Err(MultipartError::NotConsumed)));
return Poll::Ready(Some(Err(Error::NotConsumed)));
} else {
return Poll::Pending;
}
inner.poll(&this.safety, cx)
}
None => Poll::Ready(Some(Err(this
.error
Flow::Error(ref mut err) => Poll::Ready(Some(Err(err
.take()
.expect("Multipart polled after finish")))),
}
@ -191,22 +188,21 @@ struct Inner {
}
impl Inner {
fn read_field_headers(
payload: &mut PayloadBuffer,
) -> Result<Option<HeaderMap>, MultipartError> {
fn read_field_headers(payload: &mut PayloadBuffer) -> Result<Option<HeaderMap>, Error> {
match payload.read_until(b"\r\n\r\n")? {
None => {
if payload.eof {
Err(MultipartError::Incomplete)
Err(Error::Incomplete)
} else {
Ok(None)
}
}
Some(bytes) => {
let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS];
match httparse::parse_headers(&bytes, &mut hdrs) {
Ok(httparse::Status::Complete((_, hdrs))) => {
match httparse::parse_headers(&bytes, &mut hdrs).map_err(ParseError::from)? {
httparse::Status::Complete((_, hdrs)) => {
// convert headers
let mut headers = HeaderMap::with_capacity(hdrs.len());
@ -220,57 +216,84 @@ impl Inner {
Ok(Some(headers))
}
Ok(httparse::Status::Partial) => Err(ParseError::Header.into()),
Err(err) => Err(ParseError::from(err).into()),
httparse::Status::Partial => Err(ParseError::Header.into()),
}
}
}
}
fn read_boundary(
payload: &mut PayloadBuffer,
boundary: &str,
) -> Result<Option<bool>, MultipartError> {
/// Reads a field boundary from the payload buffer (and discards it).
///
/// Reads "in-between" and "final" boundaries. E.g. for boundary = "foo":
///
/// ```plain
/// --foo <-- in-between fields
/// --foo-- <-- end of request body, should be followed by EOF
/// ```
///
/// Returns:
///
/// - `Ok(Some(true))` - final field boundary read (EOF)
/// - `Ok(Some(false))` - field boundary read
/// - `Ok(None)` - boundary not found, more data needs reading
/// - `Err(BoundaryMissing)` - multipart boundary is missing
fn read_boundary(payload: &mut PayloadBuffer, boundary: &str) -> Result<Option<bool>, Error> {
// TODO: need to read epilogue
match payload.readline_or_eof()? {
None => {
if payload.eof {
Ok(Some(true))
} else {
Ok(None)
}
}
Some(chunk) => {
if chunk.len() < boundary.len() + 4
|| &chunk[..2] != b"--"
|| &chunk[2..boundary.len() + 2] != boundary.as_bytes()
{
Err(MultipartError::BoundaryMissing)
} else if &chunk[boundary.len() + 2..] == b"\r\n" {
Ok(Some(false))
} else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--"
&& (chunk.len() == boundary.len() + 4
|| &chunk[boundary.len() + 4..] == b"\r\n")
{
Ok(Some(true))
} else {
Err(MultipartError::BoundaryMissing)
}
}
let chunk = match payload.readline_or_eof()? {
// TODO: this might be okay as a let Some() else return Ok(None)
None => return Ok(payload.eof.then_some(true)),
Some(chunk) => chunk,
};
const BOUNDARY_MARKER: &[u8] = b"--";
const LINE_BREAK: &[u8] = b"\r\n";
let boundary_len = boundary.len();
if chunk.len() < boundary_len + 2 + 2
|| !chunk.starts_with(BOUNDARY_MARKER)
|| &chunk[2..boundary_len + 2] != boundary.as_bytes()
{
return Err(Error::BoundaryMissing);
}
// chunk facts:
// - long enough to contain boundary + 2 markers or 1 marker and line-break
// - starts with boundary marker
// - chunk contains correct boundary
if &chunk[boundary_len + 2..] == LINE_BREAK {
// boundary is followed by line-break, indicating more fields to come
return Ok(Some(false));
}
// boundary is followed by marker
if &chunk[boundary_len + 2..boundary_len + 4] == BOUNDARY_MARKER
&& (
// chunk is exactly boundary len + 2 markers
chunk.len() == boundary_len + 2 + 2
// final boundary is allowed to end with a line-break
|| &chunk[boundary_len + 4..] == LINE_BREAK
)
{
return Ok(Some(true));
}
Err(Error::BoundaryMissing)
}
fn skip_until_boundary(
payload: &mut PayloadBuffer,
boundary: &str,
) -> Result<Option<bool>, MultipartError> {
) -> Result<Option<bool>, Error> {
let mut eof = false;
loop {
match payload.readline()? {
Some(chunk) => {
if chunk.is_empty() {
return Err(MultipartError::BoundaryMissing);
return Err(Error::BoundaryMissing);
}
if chunk.len() < boundary.len() {
continue;
@ -292,7 +315,7 @@ impl Inner {
}
None => {
return if payload.eof {
Err(MultipartError::Incomplete)
Err(Error::Incomplete)
} else {
Ok(None)
};
@ -302,11 +325,7 @@ impl Inner {
Ok(Some(eof))
}
fn poll(
&mut self,
safety: &Safety,
cx: &Context<'_>,
) -> Poll<Option<Result<Field, MultipartError>>> {
fn poll(&mut self, safety: &Safety, cx: &Context<'_>) -> Poll<Option<Result<Field, Error>>> {
if self.state == State::Eof {
Poll::Ready(None)
} else {
@ -338,6 +357,7 @@ impl Inner {
// read until first boundary
State::FirstBoundary => {
match Inner::skip_until_boundary(&mut payload, &self.boundary)? {
None => return Poll::Pending,
Some(eof) => {
if eof {
self.state = State::Eof;
@ -346,7 +366,6 @@ impl Inner {
self.state = State::Headers;
}
}
None => return Poll::Pending,
}
}
@ -398,11 +417,11 @@ impl Inner {
// type must be set as "form-data", and it must have a name parameter.
let Some(cd) = &field_content_disposition else {
return Poll::Ready(Some(Err(MultipartError::ContentDispositionMissing)));
return Poll::Ready(Some(Err(Error::ContentDispositionMissing)));
};
let Some(field_name) = cd.get_name() else {
return Poll::Ready(Some(Err(MultipartError::ContentDispositionNameMissing)));
return Poll::Ready(Some(Err(Error::ContentDispositionNameMissing)));
};
Some(field_name.to_owned())
@ -422,7 +441,7 @@ impl Inner {
// nested multipart stream is not supported
if let Some(mime) = &field_content_type {
if mime.type_() == mime::MULTIPART {
return Poll::Ready(Some(Err(MultipartError::Nested)));
return Poll::Ready(Some(Err(Error::Nested)));
}
}
@ -475,7 +494,7 @@ mod tests {
async fn test_boundary() {
let headers = HeaderMap::new();
match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::ContentTypeMissing) => {}
Err(Error::ContentTypeMissing) => {}
_ => unreachable!("should not happen"),
}
@ -486,7 +505,7 @@ mod tests {
);
match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::ContentTypeParse) => {}
Err(Error::ContentTypeParse) => {}
_ => unreachable!("should not happen"),
}
@ -496,7 +515,7 @@ mod tests {
header::HeaderValue::from_static("multipart/mixed"),
);
match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::BoundaryMissing) => {}
Err(Error::BoundaryMissing) => {}
_ => unreachable!("should not happen"),
}
@ -831,7 +850,7 @@ mod tests {
#[actix_rt::test]
async fn test_multipart_from_error() {
let err = MultipartError::ContentTypeMissing;
let err = Error::ContentTypeMissing;
let mut multipart = Multipart::from_error(err);
assert!(multipart.next().await.unwrap().is_err())
}
@ -888,7 +907,7 @@ mod tests {
res.expect_err(
"according to RFC 7578, form-data fields require a content-disposition header"
),
MultipartError::ContentDispositionMissing
Error::ContentDispositionMissing
);
}
@ -942,7 +961,7 @@ mod tests {
let res = multipart.next().await.unwrap();
assert_matches!(
res.expect_err("according to RFC 7578, form-data fields require a name attribute"),
MultipartError::ContentDispositionNameMissing
Error::ContentDispositionNameMissing
);
}
@ -960,7 +979,7 @@ mod tests {
// should fail immediately
match field.next().await {
Some(Err(MultipartError::NotConsumed)) => {}
Some(Err(Error::NotConsumed)) => {}
_ => panic!(),
};
}

View File

@ -25,8 +25,7 @@ const BOUNDARY_PREFIX: &str = "------------------------";
///
/// ```
/// use actix_multipart::test::create_form_data_payload_and_headers;
/// use actix_web::test::TestRequest;
/// use bytes::Bytes;
/// use actix_web::{test::TestRequest, web::Bytes};
/// use memchr::memmem::find;
///
/// let (body, headers) = create_form_data_payload_and_headers(