mirror of
https://github.com/fafhrd91/actix-web
synced 2025-06-26 15:07:42 +02:00
actix-multipart: Feature: Add typed multipart form extractor (#2883)
Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
@ -1,12 +1,15 @@
|
||||
//! Error and Result module
|
||||
use actix_web::error::{ParseError, PayloadError};
|
||||
use actix_web::http::StatusCode;
|
||||
use actix_web::ResponseError;
|
||||
|
||||
use actix_web::{
|
||||
error::{ParseError, PayloadError},
|
||||
http::StatusCode,
|
||||
ResponseError,
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
|
||||
/// A set of errors that can occur during parsing multipart streams
|
||||
#[non_exhaustive]
|
||||
/// A set of errors that can occur during parsing multipart streams.
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum MultipartError {
|
||||
/// Content-Disposition header is not found or is not equal to "form-data".
|
||||
///
|
||||
@ -46,12 +49,41 @@ pub enum MultipartError {
|
||||
/// Not consumed
|
||||
#[display(fmt = "Multipart stream is not consumed")]
|
||||
NotConsumed,
|
||||
|
||||
/// An error from a field handler in a form
|
||||
#[display(
|
||||
fmt = "An error occurred processing field `{}`: {}",
|
||||
field_name,
|
||||
source
|
||||
)]
|
||||
Field {
|
||||
field_name: String,
|
||||
source: actix_web::Error,
|
||||
},
|
||||
|
||||
/// Duplicate field
|
||||
#[display(fmt = "Duplicate field found for: `{}`", _0)]
|
||||
#[from(ignore)]
|
||||
DuplicateField(#[error(not(source))] String),
|
||||
|
||||
/// Missing field
|
||||
#[display(fmt = "Field with name `{}` is required", _0)]
|
||||
#[from(ignore)]
|
||||
MissingField(#[error(not(source))] String),
|
||||
|
||||
/// Unknown field
|
||||
#[display(fmt = "Unsupported field `{}`", _0)]
|
||||
#[from(ignore)]
|
||||
UnsupportedField(#[error(not(source))] String),
|
||||
}
|
||||
|
||||
/// Return `BadRequest` for `MultipartError`
|
||||
impl ResponseError for MultipartError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
StatusCode::BAD_REQUEST
|
||||
match &self {
|
||||
MultipartError::Field { source, .. } => source.as_response_error().status_code(),
|
||||
_ => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,8 +9,7 @@ use crate::server::Multipart;
|
||||
///
|
||||
/// Content-type: multipart/form-data;
|
||||
///
|
||||
/// ## Server example
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use actix_web::{web, HttpResponse, Error};
|
||||
/// use actix_multipart::Multipart;
|
||||
|
53
actix-multipart/src/form/bytes.rs
Normal file
53
actix-multipart/src/form/bytes.rs
Normal file
@ -0,0 +1,53 @@
|
||||
//! Reads a field into memory.
|
||||
|
||||
use actix_web::HttpRequest;
|
||||
use bytes::BytesMut;
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
use futures_util::TryStreamExt as _;
|
||||
use mime::Mime;
|
||||
|
||||
use crate::{
|
||||
form::{FieldReader, Limits},
|
||||
Field, MultipartError,
|
||||
};
|
||||
|
||||
/// Read the field into memory.
|
||||
#[derive(Debug)]
|
||||
pub struct Bytes {
|
||||
/// The data.
|
||||
pub data: bytes::Bytes,
|
||||
|
||||
/// The value of the `Content-Type` header.
|
||||
pub content_type: Option<Mime>,
|
||||
|
||||
/// The `filename` value in the `Content-Disposition` header.
|
||||
pub file_name: Option<String>,
|
||||
}
|
||||
|
||||
impl<'t> FieldReader<'t> for Bytes {
|
||||
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
|
||||
|
||||
fn read_field(
|
||||
_: &'t HttpRequest,
|
||||
mut field: Field,
|
||||
limits: &'t mut Limits,
|
||||
) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let mut buf = BytesMut::with_capacity(131_072);
|
||||
|
||||
while let Some(chunk) = field.try_next().await? {
|
||||
limits.try_consume_limits(chunk.len(), true)?;
|
||||
buf.extend(chunk);
|
||||
}
|
||||
|
||||
Ok(Bytes {
|
||||
data: buf.freeze(),
|
||||
content_type: field.content_type().map(ToOwned::to_owned),
|
||||
file_name: field
|
||||
.content_disposition()
|
||||
.get_filename()
|
||||
.map(str::to_owned),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
195
actix-multipart/src/form/json.rs
Normal file
195
actix-multipart/src/form/json.rs
Normal file
@ -0,0 +1,195 @@
|
||||
//! Deserializes a field as JSON.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
|
||||
use derive_more::{Deref, DerefMut, Display, Error};
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::{
|
||||
form::{bytes::Bytes, FieldReader, Limits},
|
||||
Field, MultipartError,
|
||||
};
|
||||
|
||||
use super::FieldErrorHandler;
|
||||
|
||||
/// Deserialize from JSON.
|
||||
#[derive(Debug, Deref, DerefMut)]
|
||||
pub struct Json<T: DeserializeOwned>(pub T);
|
||||
|
||||
impl<T: DeserializeOwned> Json<T> {
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t, T> FieldReader<'t> for Json<T>
|
||||
where
|
||||
T: DeserializeOwned + 'static,
|
||||
{
|
||||
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
|
||||
|
||||
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let config = JsonConfig::from_req(req);
|
||||
let field_name = field.name().to_owned();
|
||||
|
||||
if config.validate_content_type {
|
||||
let valid = if let Some(mime) = field.content_type() {
|
||||
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if !valid {
|
||||
return Err(MultipartError::Field {
|
||||
field_name,
|
||||
source: config.map_error(req, JsonFieldError::ContentType),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let bytes = Bytes::read_field(req, field, limits).await?;
|
||||
|
||||
Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err(
|
||||
|err| MultipartError::Field {
|
||||
field_name,
|
||||
source: config.map_error(req, JsonFieldError::Deserialize(err)),
|
||||
},
|
||||
)?))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum JsonFieldError {
|
||||
/// Deserialize error.
|
||||
#[display(fmt = "Json deserialize error: {}", _0)]
|
||||
Deserialize(serde_json::Error),
|
||||
|
||||
/// Content type error.
|
||||
#[display(fmt = "Content type error")]
|
||||
ContentType,
|
||||
}
|
||||
|
||||
impl ResponseError for JsonFieldError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
StatusCode::BAD_REQUEST
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the [`Json`] field reader.
|
||||
#[derive(Clone)]
|
||||
pub struct JsonConfig {
|
||||
err_handler: FieldErrorHandler<JsonFieldError>,
|
||||
validate_content_type: bool,
|
||||
}
|
||||
|
||||
const DEFAULT_CONFIG: JsonConfig = JsonConfig {
|
||||
err_handler: None,
|
||||
validate_content_type: true,
|
||||
};
|
||||
|
||||
impl JsonConfig {
|
||||
pub fn error_handler<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(JsonFieldError, &HttpRequest) -> Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err_handler = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Extract payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
|
||||
/// back to the default payload config.
|
||||
fn from_req(req: &HttpRequest) -> &Self {
|
||||
req.app_data::<Self>()
|
||||
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
|
||||
.unwrap_or(&DEFAULT_CONFIG)
|
||||
}
|
||||
|
||||
fn map_error(&self, req: &HttpRequest, err: JsonFieldError) -> Error {
|
||||
if let Some(err_handler) = self.err_handler.as_ref() {
|
||||
(*err_handler)(err, req)
|
||||
} else {
|
||||
err.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
|
||||
pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
|
||||
self.validate_content_type = validate_content_type;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JsonConfig {
|
||||
fn default() -> Self {
|
||||
DEFAULT_CONFIG
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{collections::HashMap, io::Cursor};
|
||||
|
||||
use actix_multipart_rfc7578::client::multipart;
|
||||
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
|
||||
|
||||
use crate::form::{
|
||||
json::{Json, JsonConfig},
|
||||
tests::send_form,
|
||||
MultipartForm,
|
||||
};
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct JsonForm {
|
||||
json: Json<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
async fn test_json_route(form: MultipartForm<JsonForm>) -> impl Responder {
|
||||
let mut expected = HashMap::new();
|
||||
expected.insert("key1".to_owned(), "value1".to_owned());
|
||||
expected.insert("key2".to_owned(), "value2".to_owned());
|
||||
assert_eq!(&*form.json, &expected);
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_json_without_content_type() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/", web::post().to(test_json_route))
|
||||
.app_data(JsonConfig::default().validate_content_type(false))
|
||||
});
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("json", "{\"key1\": \"value1\", \"key2\": \"value2\"}");
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_content_type_validation() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/", web::post().to(test_json_route))
|
||||
.app_data(JsonConfig::default().validate_content_type(true))
|
||||
});
|
||||
|
||||
// Deny because wrong content type
|
||||
let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}");
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_OCTET_STREAM);
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Allow because correct content type
|
||||
let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}");
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_JSON);
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
744
actix-multipart/src/form/mod.rs
Normal file
744
actix-multipart/src/form/mod.rs
Normal file
@ -0,0 +1,744 @@
|
||||
//! Process and extract typed data from a multipart stream.
|
||||
|
||||
use std::{
|
||||
any::Any,
|
||||
collections::HashMap,
|
||||
future::{ready, Future},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest};
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
use futures_util::{TryFutureExt as _, TryStreamExt as _};
|
||||
|
||||
use crate::{Field, Multipart, MultipartError};
|
||||
|
||||
pub mod bytes;
|
||||
pub mod json;
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))]
|
||||
#[cfg(feature = "tempfile")]
|
||||
pub mod tempfile;
|
||||
pub mod text;
|
||||
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "derive")))]
|
||||
#[cfg(feature = "derive")]
|
||||
pub use actix_multipart_derive::MultipartForm;
|
||||
|
||||
type FieldErrorHandler<T> = Option<Arc<dyn Fn(T, &HttpRequest) -> Error + Send + Sync>>;
|
||||
|
||||
/// Trait that data types to be used in a multipart form struct should implement.
|
||||
///
|
||||
/// It represents an asynchronous handler that processes a multipart field to produce `Self`.
|
||||
pub trait FieldReader<'t>: Sized + Any {
|
||||
/// Future that resolves to a `Self`.
|
||||
type Future: Future<Output = Result<Self, MultipartError>>;
|
||||
|
||||
/// The form will call this function to handle the field.
|
||||
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future;
|
||||
}
|
||||
|
||||
/// Used to accumulate the state of the loaded fields.
|
||||
#[doc(hidden)]
|
||||
#[derive(Default, Deref, DerefMut)]
|
||||
pub struct State(pub HashMap<String, Box<dyn Any>>);
|
||||
|
||||
/// Trait that the field collection types implement, i.e. `Vec<T>`, `Option<T>`, or `T` itself.
|
||||
#[doc(hidden)]
|
||||
pub trait FieldGroupReader<'t>: Sized + Any {
|
||||
type Future: Future<Output = Result<(), MultipartError>>;
|
||||
|
||||
/// The form will call this function for each matching field.
|
||||
fn handle_field(
|
||||
req: &'t HttpRequest,
|
||||
field: Field,
|
||||
limits: &'t mut Limits,
|
||||
state: &'t mut State,
|
||||
duplicate_field: DuplicateField,
|
||||
) -> Self::Future;
|
||||
|
||||
/// Construct `Self` from the group of processed fields.
|
||||
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError>;
|
||||
}
|
||||
|
||||
impl<'t, T> FieldGroupReader<'t> for Option<T>
|
||||
where
|
||||
T: FieldReader<'t>,
|
||||
{
|
||||
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
|
||||
|
||||
fn handle_field(
|
||||
req: &'t HttpRequest,
|
||||
field: Field,
|
||||
limits: &'t mut Limits,
|
||||
state: &'t mut State,
|
||||
duplicate_field: DuplicateField,
|
||||
) -> Self::Future {
|
||||
if state.contains_key(field.name()) {
|
||||
match duplicate_field {
|
||||
DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
|
||||
|
||||
DuplicateField::Deny => {
|
||||
return Box::pin(ready(Err(MultipartError::DuplicateField(
|
||||
field.name().to_owned(),
|
||||
))))
|
||||
}
|
||||
|
||||
DuplicateField::Replace => {}
|
||||
}
|
||||
}
|
||||
|
||||
Box::pin(async move {
|
||||
let field_name = field.name().to_owned();
|
||||
let t = T::read_field(req, field, limits).await?;
|
||||
state.insert(field_name, Box::new(t));
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
|
||||
Ok(state.remove(name).map(|m| *m.downcast::<T>().unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t, T> FieldGroupReader<'t> for Vec<T>
|
||||
where
|
||||
T: FieldReader<'t>,
|
||||
{
|
||||
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
|
||||
|
||||
fn handle_field(
|
||||
req: &'t HttpRequest,
|
||||
field: Field,
|
||||
limits: &'t mut Limits,
|
||||
state: &'t mut State,
|
||||
_duplicate_field: DuplicateField,
|
||||
) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
// Note: Vec GroupReader always allows duplicates
|
||||
|
||||
let field_name = field.name().to_owned();
|
||||
|
||||
let vec = state
|
||||
.entry(field_name)
|
||||
.or_insert_with(|| Box::<Vec<T>>::default())
|
||||
.downcast_mut::<Vec<T>>()
|
||||
.unwrap();
|
||||
|
||||
let item = T::read_field(req, field, limits).await?;
|
||||
vec.push(item);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
|
||||
Ok(state
|
||||
.remove(name)
|
||||
.map(|m| *m.downcast::<Vec<T>>().unwrap())
|
||||
.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t, T> FieldGroupReader<'t> for T
|
||||
where
|
||||
T: FieldReader<'t>,
|
||||
{
|
||||
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
|
||||
|
||||
fn handle_field(
|
||||
req: &'t HttpRequest,
|
||||
field: Field,
|
||||
limits: &'t mut Limits,
|
||||
state: &'t mut State,
|
||||
duplicate_field: DuplicateField,
|
||||
) -> Self::Future {
|
||||
if state.contains_key(field.name()) {
|
||||
match duplicate_field {
|
||||
DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
|
||||
|
||||
DuplicateField::Deny => {
|
||||
return Box::pin(ready(Err(MultipartError::DuplicateField(
|
||||
field.name().to_owned(),
|
||||
))))
|
||||
}
|
||||
|
||||
DuplicateField::Replace => {}
|
||||
}
|
||||
}
|
||||
|
||||
Box::pin(async move {
|
||||
let field_name = field.name().to_owned();
|
||||
let t = T::read_field(req, field, limits).await?;
|
||||
state.insert(field_name, Box::new(t));
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
|
||||
state
|
||||
.remove(name)
|
||||
.map(|m| *m.downcast::<T>().unwrap())
|
||||
.ok_or_else(|| MultipartError::MissingField(name.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that allows a type to be used in the [`struct@MultipartForm`] extractor.
|
||||
///
|
||||
/// You should use the [`macro@MultipartForm`] macro to derive this for your struct.
|
||||
pub trait MultipartCollect: Sized {
|
||||
/// An optional limit in bytes to be applied a given field name. Note this limit will be shared
|
||||
/// across all fields sharing the same name.
|
||||
fn limit(field_name: &str) -> Option<usize>;
|
||||
|
||||
/// The extractor will call this function for each incoming field, the state can be updated
|
||||
/// with the processed field data.
|
||||
fn handle_field<'t>(
|
||||
req: &'t HttpRequest,
|
||||
field: Field,
|
||||
limits: &'t mut Limits,
|
||||
state: &'t mut State,
|
||||
) -> LocalBoxFuture<'t, Result<(), MultipartError>>;
|
||||
|
||||
/// Once all the fields have been processed and stored in the state, this is called
|
||||
/// to convert into the struct representation.
|
||||
fn from_state(state: State) -> Result<Self, MultipartError>;
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub enum DuplicateField {
|
||||
/// Additional fields are not processed.
|
||||
Ignore,
|
||||
|
||||
/// An error will be raised.
|
||||
Deny,
|
||||
|
||||
/// All fields will be processed, the last one will replace all previous.
|
||||
Replace,
|
||||
}
|
||||
|
||||
/// Used to keep track of the remaining limits for the form and current field.
|
||||
pub struct Limits {
|
||||
pub total_limit_remaining: usize,
|
||||
pub memory_limit_remaining: usize,
|
||||
pub field_limit_remaining: Option<usize>,
|
||||
}
|
||||
|
||||
impl Limits {
|
||||
pub fn new(total_limit: usize, memory_limit: usize) -> Self {
|
||||
Self {
|
||||
total_limit_remaining: total_limit,
|
||||
memory_limit_remaining: memory_limit,
|
||||
field_limit_remaining: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// This function should be called within a [`FieldReader`] when reading each chunk of a field
|
||||
/// to ensure that the form limits are not exceeded.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `bytes` - The number of bytes being read from this chunk
|
||||
/// * `in_memory` - Whether to consume from the memory limits
|
||||
pub fn try_consume_limits(
|
||||
&mut self,
|
||||
bytes: usize,
|
||||
in_memory: bool,
|
||||
) -> Result<(), MultipartError> {
|
||||
self.total_limit_remaining = self
|
||||
.total_limit_remaining
|
||||
.checked_sub(bytes)
|
||||
.ok_or(MultipartError::Payload(PayloadError::Overflow))?;
|
||||
|
||||
if in_memory {
|
||||
self.memory_limit_remaining = self
|
||||
.memory_limit_remaining
|
||||
.checked_sub(bytes)
|
||||
.ok_or(MultipartError::Payload(PayloadError::Overflow))?;
|
||||
}
|
||||
|
||||
if let Some(field_limit) = self.field_limit_remaining {
|
||||
self.field_limit_remaining = Some(
|
||||
field_limit
|
||||
.checked_sub(bytes)
|
||||
.ok_or(MultipartError::Payload(PayloadError::Overflow))?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Typed `multipart/form-data` extractor.
|
||||
///
|
||||
/// To extract typed data from a multipart stream, the inner type `T` must implement the
|
||||
/// [`MultipartCollect`] trait. You should use the [`macro@MultipartForm`] macro to derive this
|
||||
/// for your struct.
|
||||
///
|
||||
/// Add a [`MultipartFormConfig`] to your app data to configure extraction.
|
||||
#[derive(Deref, DerefMut)]
|
||||
pub struct MultipartForm<T: MultipartCollect>(pub T);
|
||||
|
||||
impl<T: MultipartCollect> MultipartForm<T> {
|
||||
/// Unwrap into inner `T` value.
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FromRequest for MultipartForm<T>
|
||||
where
|
||||
T: MultipartCollect,
|
||||
{
|
||||
type Error = Error;
|
||||
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
|
||||
|
||||
#[inline]
|
||||
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
|
||||
let mut payload = Multipart::new(req.headers(), payload.take());
|
||||
|
||||
let config = MultipartFormConfig::from_req(req);
|
||||
let mut limits = Limits::new(config.total_limit, config.memory_limit);
|
||||
|
||||
let req = req.clone();
|
||||
let req2 = req.clone();
|
||||
let err_handler = config.err_handler.clone();
|
||||
|
||||
Box::pin(
|
||||
async move {
|
||||
let mut state = State::default();
|
||||
// We need to ensure field limits are shared for all instances of this field name
|
||||
let mut field_limits = HashMap::<String, Option<usize>>::new();
|
||||
|
||||
while let Some(field) = payload.try_next().await? {
|
||||
// Retrieve the limit for this field
|
||||
let entry = field_limits
|
||||
.entry(field.name().to_owned())
|
||||
.or_insert_with(|| T::limit(field.name()));
|
||||
limits.field_limit_remaining = entry.to_owned();
|
||||
|
||||
T::handle_field(&req, field, &mut limits, &mut state).await?;
|
||||
|
||||
// Update the stored limit
|
||||
*entry = limits.field_limit_remaining;
|
||||
}
|
||||
let inner = T::from_state(state)?;
|
||||
Ok(MultipartForm(inner))
|
||||
}
|
||||
.map_err(move |err| {
|
||||
if let Some(handler) = err_handler {
|
||||
(*handler)(err, &req2)
|
||||
} else {
|
||||
err.into()
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type MultipartFormErrorHandler =
|
||||
Option<Arc<dyn Fn(MultipartError, &HttpRequest) -> Error + Send + Sync>>;
|
||||
|
||||
/// [`struct@MultipartForm`] extractor configuration.
|
||||
///
|
||||
/// Add to your app data to have it picked up by [`struct@MultipartForm`] extractors.
|
||||
#[derive(Clone)]
|
||||
pub struct MultipartFormConfig {
|
||||
total_limit: usize,
|
||||
memory_limit: usize,
|
||||
err_handler: MultipartFormErrorHandler,
|
||||
}
|
||||
|
||||
impl MultipartFormConfig {
|
||||
/// Sets maximum accepted payload size for the entire form. By default this limit is 50MiB.
|
||||
pub fn total_limit(mut self, total_limit: usize) -> Self {
|
||||
self.total_limit = total_limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets maximum accepted data that will be read into memory. By default this limit is 2MiB.
|
||||
pub fn memory_limit(mut self, memory_limit: usize) -> Self {
|
||||
self.memory_limit = memory_limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets custom error handler.
|
||||
pub fn error_handler<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(MultipartError, &HttpRequest) -> Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err_handler = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
|
||||
/// back to the default payload config.
|
||||
fn from_req(req: &HttpRequest) -> &Self {
|
||||
req.app_data::<Self>()
|
||||
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
|
||||
.unwrap_or(&DEFAULT_CONFIG)
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_CONFIG: MultipartFormConfig = MultipartFormConfig {
|
||||
total_limit: 52_428_800, // 50 MiB
|
||||
memory_limit: 2_097_152, // 2 MiB
|
||||
err_handler: None,
|
||||
};
|
||||
|
||||
impl Default for MultipartFormConfig {
|
||||
fn default() -> Self {
|
||||
DEFAULT_CONFIG
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use actix_http::encoding::Decoder;
|
||||
use actix_multipart_rfc7578::client::multipart;
|
||||
use actix_test::TestServer;
|
||||
use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder};
|
||||
use awc::{Client, ClientResponse};
|
||||
|
||||
use super::MultipartForm;
|
||||
use crate::form::{bytes::Bytes, tempfile::TempFile, text::Text, MultipartFormConfig};
|
||||
|
||||
pub async fn send_form(
|
||||
srv: &TestServer,
|
||||
form: multipart::Form<'static>,
|
||||
uri: &'static str,
|
||||
) -> ClientResponse<Decoder<Payload>> {
|
||||
Client::default()
|
||||
.post(srv.url(uri))
|
||||
.content_type(form.content_type())
|
||||
.send_body(multipart::Body::from(form))
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Test `Option` fields.
|
||||
#[derive(MultipartForm)]
|
||||
struct TestOptions {
|
||||
field1: Option<Text<String>>,
|
||||
field2: Option<Text<String>>,
|
||||
}
|
||||
|
||||
async fn test_options_route(form: MultipartForm<TestOptions>) -> impl Responder {
|
||||
assert!(form.field1.is_some());
|
||||
assert!(form.field2.is_none());
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_options() {
|
||||
let srv =
|
||||
actix_test::start(|| App::new().route("/", web::post().to(test_options_route)));
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field1", "value");
|
||||
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test `Vec` fields.
|
||||
#[derive(MultipartForm)]
|
||||
struct TestVec {
|
||||
list1: Vec<Text<String>>,
|
||||
list2: Vec<Text<String>>,
|
||||
}
|
||||
|
||||
async fn test_vec_route(form: MultipartForm<TestVec>) -> impl Responder {
|
||||
let form = form.into_inner();
|
||||
let strings = form
|
||||
.list1
|
||||
.into_iter()
|
||||
.map(|s| s.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(strings, vec!["value1", "value2", "value3"]);
|
||||
assert_eq!(form.list2.len(), 0);
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_vec() {
|
||||
let srv = actix_test::start(|| App::new().route("/", web::post().to(test_vec_route)));
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("list1", "value1");
|
||||
form.add_text("list1", "value2");
|
||||
form.add_text("list1", "value3");
|
||||
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test the `rename` field attribute.
|
||||
#[derive(MultipartForm)]
|
||||
struct TestFieldRenaming {
|
||||
#[multipart(rename = "renamed")]
|
||||
field1: Text<String>,
|
||||
#[multipart(rename = "field1")]
|
||||
field2: Text<String>,
|
||||
field3: Text<String>,
|
||||
}
|
||||
|
||||
async fn test_field_renaming_route(
|
||||
form: MultipartForm<TestFieldRenaming>,
|
||||
) -> impl Responder {
|
||||
assert_eq!(&*form.field1, "renamed");
|
||||
assert_eq!(&*form.field2, "field1");
|
||||
assert_eq!(&*form.field3, "field3");
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_field_renaming() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new().route("/", web::post().to(test_field_renaming_route))
|
||||
});
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("renamed", "renamed");
|
||||
form.add_text("field1", "field1");
|
||||
form.add_text("field3", "field3");
|
||||
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test the `deny_unknown_fields` struct attribute.
|
||||
#[derive(MultipartForm)]
|
||||
#[multipart(deny_unknown_fields)]
|
||||
struct TestDenyUnknown {}
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct TestAllowUnknown {}
|
||||
|
||||
async fn test_deny_unknown_route(_: MultipartForm<TestDenyUnknown>) -> impl Responder {
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
async fn test_allow_unknown_route(_: MultipartForm<TestAllowUnknown>) -> impl Responder {
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_deny_unknown() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/deny", web::post().to(test_deny_unknown_route))
|
||||
.route("/allow", web::post().to(test_allow_unknown_route))
|
||||
});
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("unknown", "value");
|
||||
let response = send_form(&srv, form, "/deny").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("unknown", "value");
|
||||
let response = send_form(&srv, form, "/allow").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test the `duplicate_field` struct attribute.
|
||||
#[derive(MultipartForm)]
|
||||
#[multipart(duplicate_field = "deny")]
|
||||
struct TestDuplicateDeny {
|
||||
_field: Text<String>,
|
||||
}
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
#[multipart(duplicate_field = "replace")]
|
||||
struct TestDuplicateReplace {
|
||||
field: Text<String>,
|
||||
}
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
#[multipart(duplicate_field = "ignore")]
|
||||
struct TestDuplicateIgnore {
|
||||
field: Text<String>,
|
||||
}
|
||||
|
||||
async fn test_duplicate_deny_route(_: MultipartForm<TestDuplicateDeny>) -> impl Responder {
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
async fn test_duplicate_replace_route(
|
||||
form: MultipartForm<TestDuplicateReplace>,
|
||||
) -> impl Responder {
|
||||
assert_eq!(&*form.field, "second_value");
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
async fn test_duplicate_ignore_route(
|
||||
form: MultipartForm<TestDuplicateIgnore>,
|
||||
) -> impl Responder {
|
||||
assert_eq!(&*form.field, "first_value");
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_duplicate_field() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/deny", web::post().to(test_duplicate_deny_route))
|
||||
.route("/replace", web::post().to(test_duplicate_replace_route))
|
||||
.route("/ignore", web::post().to(test_duplicate_ignore_route))
|
||||
});
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("_field", "first_value");
|
||||
form.add_text("_field", "second_value");
|
||||
let response = send_form(&srv, form, "/deny").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "first_value");
|
||||
form.add_text("field", "second_value");
|
||||
let response = send_form(&srv, form, "/replace").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "first_value");
|
||||
form.add_text("field", "second_value");
|
||||
let response = send_form(&srv, form, "/ignore").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test the Limits.
|
||||
#[derive(MultipartForm)]
|
||||
struct TestMemoryUploadLimits {
|
||||
field: Bytes,
|
||||
}
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct TestFileUploadLimits {
|
||||
field: TempFile,
|
||||
}
|
||||
|
||||
async fn test_upload_limits_memory(
|
||||
form: MultipartForm<TestMemoryUploadLimits>,
|
||||
) -> impl Responder {
|
||||
assert!(!form.field.data.is_empty());
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
async fn test_upload_limits_file(
|
||||
form: MultipartForm<TestFileUploadLimits>,
|
||||
) -> impl Responder {
|
||||
assert!(form.field.size > 0);
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_memory_limits() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/text", web::post().to(test_upload_limits_memory))
|
||||
.route("/file", web::post().to(test_upload_limits_file))
|
||||
.app_data(
|
||||
MultipartFormConfig::default()
|
||||
.memory_limit(20)
|
||||
.total_limit(usize::MAX),
|
||||
)
|
||||
});
|
||||
|
||||
// Exceeds the 20 byte memory limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/text").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Memory limit should not apply when the data is being streamed to disk
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/file").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_total_limit() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/text", web::post().to(test_upload_limits_memory))
|
||||
.route("/file", web::post().to(test_upload_limits_file))
|
||||
.app_data(
|
||||
MultipartFormConfig::default()
|
||||
.memory_limit(usize::MAX)
|
||||
.total_limit(20),
|
||||
)
|
||||
});
|
||||
|
||||
// Within the 20 byte limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "7 bytes");
|
||||
let response = send_form(&srv, form, "/text").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Exceeds the 20 byte overall limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/text").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Exceeds the 20 byte overall limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/file").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct TestFieldLevelLimits {
|
||||
#[multipart(limit = "30B")]
|
||||
field: Vec<Bytes>,
|
||||
}
|
||||
|
||||
async fn test_field_level_limits_route(
|
||||
form: MultipartForm<TestFieldLevelLimits>,
|
||||
) -> impl Responder {
|
||||
assert!(!form.field.is_empty());
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_field_level_limits() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/", web::post().to(test_field_level_limits_route))
|
||||
.app_data(
|
||||
MultipartFormConfig::default()
|
||||
.memory_limit(usize::MAX)
|
||||
.total_limit(usize::MAX),
|
||||
)
|
||||
});
|
||||
|
||||
// Within the 30 byte limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Exceeds the the 30 byte limit
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is more than 30 bytes long");
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Total of values (14 bytes) is within 30 byte limit for "field"
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "7 bytes");
|
||||
form.add_text("field", "7 bytes");
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Total of values exceeds 30 byte limit for "field"
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
form.add_text("field", "this string is 28 bytes long");
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
206
actix-multipart/src/form/tempfile.rs
Normal file
206
actix-multipart/src/form/tempfile.rs
Normal file
@ -0,0 +1,206 @@
|
||||
//! Writes a field to a temporary file on disk.
|
||||
|
||||
use std::{
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
|
||||
use derive_more::{Display, Error};
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
use futures_util::TryStreamExt as _;
|
||||
use mime::Mime;
|
||||
use tempfile_dep::NamedTempFile;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use super::FieldErrorHandler;
|
||||
use crate::{
|
||||
form::{FieldReader, Limits},
|
||||
Field, MultipartError,
|
||||
};
|
||||
|
||||
/// Write the field to a temporary file on disk.
|
||||
#[derive(Debug)]
|
||||
pub struct TempFile {
|
||||
/// The temporary file on disk.
|
||||
pub file: NamedTempFile,
|
||||
|
||||
/// The value of the `content-type` header.
|
||||
pub content_type: Option<Mime>,
|
||||
|
||||
/// The `filename` value in the `content-disposition` header.
|
||||
pub file_name: Option<String>,
|
||||
|
||||
/// The size in bytes of the file.
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl<'t> FieldReader<'t> for TempFile {
|
||||
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
|
||||
|
||||
fn read_field(
|
||||
req: &'t HttpRequest,
|
||||
mut field: Field,
|
||||
limits: &'t mut Limits,
|
||||
) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let config = TempFileConfig::from_req(req);
|
||||
let field_name = field.name().to_owned();
|
||||
let mut size = 0;
|
||||
|
||||
let file = config.create_tempfile().map_err(|err| {
|
||||
config.map_error(req, &field_name, TempFileError::FileIo(err))
|
||||
})?;
|
||||
|
||||
let mut file_async = tokio::fs::File::from_std(file.reopen().map_err(|err| {
|
||||
config.map_error(req, &field_name, TempFileError::FileIo(err))
|
||||
})?);
|
||||
|
||||
while let Some(chunk) = field.try_next().await? {
|
||||
limits.try_consume_limits(chunk.len(), false)?;
|
||||
size += chunk.len();
|
||||
file_async.write_all(chunk.as_ref()).await.map_err(|err| {
|
||||
config.map_error(req, &field_name, TempFileError::FileIo(err))
|
||||
})?;
|
||||
}
|
||||
|
||||
file_async.flush().await.map_err(|err| {
|
||||
config.map_error(req, &field_name, TempFileError::FileIo(err))
|
||||
})?;
|
||||
|
||||
Ok(TempFile {
|
||||
file,
|
||||
content_type: field.content_type().map(ToOwned::to_owned),
|
||||
file_name: field
|
||||
.content_disposition()
|
||||
.get_filename()
|
||||
.map(str::to_owned),
|
||||
size,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum TempFileError {
|
||||
/// File I/O Error
|
||||
#[display(fmt = "File I/O error: {}", _0)]
|
||||
FileIo(std::io::Error),
|
||||
}
|
||||
|
||||
impl ResponseError for TempFileError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the [`TempFile`] field reader.
|
||||
#[derive(Clone)]
|
||||
pub struct TempFileConfig {
|
||||
err_handler: FieldErrorHandler<TempFileError>,
|
||||
directory: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl TempFileConfig {
|
||||
fn create_tempfile(&self) -> io::Result<NamedTempFile> {
|
||||
if let Some(ref dir) = self.directory {
|
||||
NamedTempFile::new_in(dir)
|
||||
} else {
|
||||
NamedTempFile::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TempFileConfig {
|
||||
/// Sets custom error handler.
|
||||
pub fn error_handler<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(TempFileError, &HttpRequest) -> Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err_handler = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
|
||||
/// back to the default payload config.
|
||||
fn from_req(req: &HttpRequest) -> &Self {
|
||||
req.app_data::<Self>()
|
||||
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
|
||||
.unwrap_or(&DEFAULT_CONFIG)
|
||||
}
|
||||
|
||||
fn map_error(
|
||||
&self,
|
||||
req: &HttpRequest,
|
||||
field_name: &str,
|
||||
err: TempFileError,
|
||||
) -> MultipartError {
|
||||
let source = if let Some(ref err_handler) = self.err_handler {
|
||||
(err_handler)(err, req)
|
||||
} else {
|
||||
err.into()
|
||||
};
|
||||
|
||||
MultipartError::Field {
|
||||
field_name: field_name.to_owned(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the directory that temp files will be created in.
|
||||
///
|
||||
/// The default temporary file location is platform dependent.
|
||||
pub fn directory(mut self, dir: impl AsRef<Path>) -> Self {
|
||||
self.directory = Some(dir.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_CONFIG: TempFileConfig = TempFileConfig {
|
||||
err_handler: None,
|
||||
directory: None,
|
||||
};
|
||||
|
||||
impl Default for TempFileConfig {
|
||||
fn default() -> Self {
|
||||
DEFAULT_CONFIG
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io::{Cursor, Read};
|
||||
|
||||
use actix_multipart_rfc7578::client::multipart;
|
||||
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
|
||||
|
||||
use crate::form::{tempfile::TempFile, tests::send_form, MultipartForm};
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct FileForm {
|
||||
file: TempFile,
|
||||
}
|
||||
|
||||
async fn test_file_route(form: MultipartForm<FileForm>) -> impl Responder {
|
||||
let mut form = form.into_inner();
|
||||
let mut contents = String::new();
|
||||
form.file.file.read_to_string(&mut contents).unwrap();
|
||||
assert_eq!(contents, "Hello, world!");
|
||||
assert_eq!(form.file.file_name.unwrap(), "testfile.txt");
|
||||
assert_eq!(form.file.content_type.unwrap(), mime::TEXT_PLAIN);
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_file_upload() {
|
||||
let srv = actix_test::start(|| App::new().route("/", web::post().to(test_file_route)));
|
||||
|
||||
let mut form = multipart::Form::default();
|
||||
let bytes = Cursor::new("Hello, world!");
|
||||
form.add_reader_file_with_mime("file", bytes, "testfile.txt", mime::TEXT_PLAIN);
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
196
actix-multipart/src/form/text.rs
Normal file
196
actix-multipart/src/form/text.rs
Normal file
@ -0,0 +1,196 @@
|
||||
//! Deserializes a field from plain text.
|
||||
|
||||
use std::{str, sync::Arc};
|
||||
|
||||
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
|
||||
use derive_more::{Deref, DerefMut, Display, Error};
|
||||
use futures_core::future::LocalBoxFuture;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use super::FieldErrorHandler;
|
||||
use crate::{
|
||||
form::{bytes::Bytes, FieldReader, Limits},
|
||||
Field, MultipartError,
|
||||
};
|
||||
|
||||
/// Deserialize from plain text.
|
||||
///
|
||||
/// Internally this uses [`serde_plain`] for deserialization, which supports primitive types
|
||||
/// including strings, numbers, and simple enums.
|
||||
#[derive(Debug, Deref, DerefMut)]
|
||||
pub struct Text<T: DeserializeOwned>(pub T);
|
||||
|
||||
impl<T: DeserializeOwned> Text<T> {
|
||||
/// Unwraps into inner value.
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t, T> FieldReader<'t> for Text<T>
|
||||
where
|
||||
T: DeserializeOwned + 'static,
|
||||
{
|
||||
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
|
||||
|
||||
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let config = TextConfig::from_req(req);
|
||||
let field_name = field.name().to_owned();
|
||||
|
||||
if config.validate_content_type {
|
||||
let valid = if let Some(mime) = field.content_type() {
|
||||
mime.subtype() == mime::PLAIN || mime.suffix() == Some(mime::PLAIN)
|
||||
} else {
|
||||
// https://datatracker.ietf.org/doc/html/rfc7578#section-4.4
|
||||
// content type defaults to text/plain, so None should be considered valid
|
||||
true
|
||||
};
|
||||
|
||||
if !valid {
|
||||
return Err(MultipartError::Field {
|
||||
field_name,
|
||||
source: config.map_error(req, TextError::ContentType),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let bytes = Bytes::read_field(req, field, limits).await?;
|
||||
|
||||
let text = str::from_utf8(&bytes.data).map_err(|err| MultipartError::Field {
|
||||
field_name: field_name.clone(),
|
||||
source: config.map_error(req, TextError::Utf8Error(err)),
|
||||
})?;
|
||||
|
||||
Ok(Text(serde_plain::from_str(text).map_err(|err| {
|
||||
MultipartError::Field {
|
||||
field_name,
|
||||
source: config.map_error(req, TextError::Deserialize(err)),
|
||||
}
|
||||
})?))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum TextError {
|
||||
/// UTF-8 decoding error.
|
||||
#[display(fmt = "UTF-8 decoding error: {}", _0)]
|
||||
Utf8Error(str::Utf8Error),
|
||||
|
||||
/// Deserialize error.
|
||||
#[display(fmt = "Plain text deserialize error: {}", _0)]
|
||||
Deserialize(serde_plain::Error),
|
||||
|
||||
/// Content type error.
|
||||
#[display(fmt = "Content type error")]
|
||||
ContentType,
|
||||
}
|
||||
|
||||
impl ResponseError for TextError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
StatusCode::BAD_REQUEST
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the [`Text`] field reader.
|
||||
#[derive(Clone)]
|
||||
pub struct TextConfig {
|
||||
err_handler: FieldErrorHandler<TextError>,
|
||||
validate_content_type: bool,
|
||||
}
|
||||
|
||||
impl TextConfig {
|
||||
/// Sets custom error handler.
|
||||
pub fn error_handler<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(TextError, &HttpRequest) -> Error + Send + Sync + 'static,
|
||||
{
|
||||
self.err_handler = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
|
||||
/// back to the default payload config.
|
||||
fn from_req(req: &HttpRequest) -> &Self {
|
||||
req.app_data::<Self>()
|
||||
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
|
||||
.unwrap_or(&DEFAULT_CONFIG)
|
||||
}
|
||||
|
||||
fn map_error(&self, req: &HttpRequest, err: TextError) -> Error {
|
||||
if let Some(ref err_handler) = self.err_handler {
|
||||
(err_handler)(err, req)
|
||||
} else {
|
||||
err.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
|
||||
///
|
||||
/// Note that an empty `Content-Type` is also accepted, as the multipart specification defines
|
||||
/// `text/plain` as the default for text fields.
|
||||
pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
|
||||
self.validate_content_type = validate_content_type;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_CONFIG: TextConfig = TextConfig {
|
||||
err_handler: None,
|
||||
validate_content_type: true,
|
||||
};
|
||||
|
||||
impl Default for TextConfig {
|
||||
fn default() -> Self {
|
||||
DEFAULT_CONFIG
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io::Cursor;
|
||||
|
||||
use actix_multipart_rfc7578::client::multipart;
|
||||
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
|
||||
|
||||
use crate::form::{
|
||||
tests::send_form,
|
||||
text::{Text, TextConfig},
|
||||
MultipartForm,
|
||||
};
|
||||
|
||||
#[derive(MultipartForm)]
|
||||
struct TextForm {
|
||||
number: Text<i32>,
|
||||
}
|
||||
|
||||
async fn test_text_route(form: MultipartForm<TextForm>) -> impl Responder {
|
||||
assert_eq!(*form.number, 1025);
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_content_type_validation() {
|
||||
let srv = actix_test::start(|| {
|
||||
App::new()
|
||||
.route("/", web::post().to(test_text_route))
|
||||
.app_data(TextConfig::default().validate_content_type(true))
|
||||
});
|
||||
|
||||
// Deny because wrong content type
|
||||
let bytes = Cursor::new("1025");
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_reader_file_with_mime("number", bytes, "", mime::APPLICATION_OCTET_STREAM);
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Allow because correct content type
|
||||
let bytes = Cursor::new("1025");
|
||||
let mut form = multipart::Form::default();
|
||||
form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN);
|
||||
let response = send_form(&srv, form, "/").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
@ -3,10 +3,17 @@
|
||||
#![deny(rust_2018_idioms, nonstandard_style)]
|
||||
#![warn(future_incompatible)]
|
||||
#![allow(clippy::borrow_interior_mutable_const, clippy::uninlined_format_args)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
// This allows us to use the actix_multipart_derive within this crate's tests
|
||||
#[cfg(test)]
|
||||
extern crate self as actix_multipart;
|
||||
|
||||
mod error;
|
||||
mod extractor;
|
||||
mod server;
|
||||
|
||||
pub mod form;
|
||||
|
||||
pub use self::error::MultipartError;
|
||||
pub use self::server::{Field, Multipart};
|
||||
|
@ -270,7 +270,9 @@ impl InnerMultipart {
|
||||
match field.borrow_mut().poll(safety) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Some(Ok(_))) => continue,
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
return Poll::Ready(Some(Err(err)))
|
||||
}
|
||||
Poll::Ready(None) => true,
|
||||
}
|
||||
}
|
||||
@ -658,7 +660,7 @@ impl InnerField {
|
||||
match res {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(None) => self.eof = true,
|
||||
}
|
||||
}
|
||||
@ -673,7 +675,7 @@ impl InnerField {
|
||||
}
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Err(e) => Poll::Ready(Some(Err(e))),
|
||||
Err(err) => Poll::Ready(Some(Err(err))),
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
@ -794,7 +796,7 @@ impl PayloadBuffer {
|
||||
loop {
|
||||
match Pin::new(&mut self.stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
|
||||
Poll::Ready(Some(Err(e))) => return Err(e),
|
||||
Poll::Ready(Some(Err(err))) => return Err(err),
|
||||
Poll::Ready(None) => {
|
||||
self.eof = true;
|
||||
return Ok(());
|
||||
@ -860,19 +862,22 @@ impl PayloadBuffer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_http::h1::Payload;
|
||||
use actix_web::http::header::{DispositionParam, DispositionType};
|
||||
use actix_web::rt;
|
||||
use actix_web::test::TestRequest;
|
||||
use actix_web::FromRequest;
|
||||
use actix_http::h1;
|
||||
use actix_web::{
|
||||
http::header::{DispositionParam, DispositionType},
|
||||
rt,
|
||||
test::TestRequest,
|
||||
FromRequest,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future::lazy, StreamExt as _};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_boundary() {
|
||||
let headers = HeaderMap::new();
|
||||
@ -1119,7 +1124,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_basic() {
|
||||
let (_, payload) = Payload::create(false);
|
||||
let (_, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
|
||||
assert_eq!(payload.buf.len(), 0);
|
||||
@ -1129,7 +1134,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_eof() {
|
||||
let (mut sender, payload) = Payload::create(false);
|
||||
let (mut sender, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
|
||||
assert_eq!(None, payload.read_max(4).unwrap());
|
||||
@ -1145,7 +1150,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_err() {
|
||||
let (mut sender, payload) = Payload::create(false);
|
||||
let (mut sender, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
assert_eq!(None, payload.read_max(1).unwrap());
|
||||
sender.set_error(PayloadError::Incomplete(None));
|
||||
@ -1154,7 +1159,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_readmax() {
|
||||
let (mut sender, payload) = Payload::create(false);
|
||||
let (mut sender, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
|
||||
sender.feed_data(Bytes::from("line1"));
|
||||
@ -1171,7 +1176,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_readexactly() {
|
||||
let (mut sender, payload) = Payload::create(false);
|
||||
let (mut sender, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
|
||||
assert_eq!(None, payload.read_exact(2));
|
||||
@ -1189,7 +1194,7 @@ mod tests {
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_readuntil() {
|
||||
let (mut sender, payload) = Payload::create(false);
|
||||
let (mut sender, payload) = h1::Payload::create(false);
|
||||
let mut payload = PayloadBuffer::new(payload);
|
||||
|
||||
assert_eq!(None, payload.read_until(b"ne").unwrap());
|
||||
@ -1230,7 +1235,7 @@ mod tests {
|
||||
#[actix_rt::test]
|
||||
async fn test_multipart_payload_consumption() {
|
||||
// with sample payload and HttpRequest with no headers
|
||||
let (_, inner_payload) = Payload::create(false);
|
||||
let (_, inner_payload) = h1::Payload::create(false);
|
||||
let mut payload = actix_web::dev::Payload::from(inner_payload);
|
||||
let req = TestRequest::default().to_http_request();
|
||||
|
||||
|
Reference in New Issue
Block a user