diff --git a/src/types/payload.rs b/src/types/payload.rs index acb8b9a82..fd4d3e945 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -7,10 +7,15 @@ use std::task::{Context, Poll}; use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::HttpMessage; use bytes::{Bytes, BytesMut}; -use encoding_rs::UTF_8; +use encoding_rs::{Encoding, UTF_8}; use futures_core::stream::Stream; -use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; +use futures_util::{ + future::{ + err, ok, Either, ErrInto, FutureExt as _, LocalBoxFuture, Ready, + TryFutureExt as _, + }, + stream::StreamExt as _, +}; use mime::Mime; use crate::extract::FromRequest; @@ -135,10 +140,7 @@ impl FromRequest for Payload { impl FromRequest for Bytes { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either, Ready>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -151,7 +153,7 @@ impl FromRequest for Bytes { let limit = cfg.limit; let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left(async move { Ok(fut.await?) }.boxed_local()) + Either::Left(fut.err_into()) } } @@ -185,10 +187,7 @@ impl FromRequest for Bytes { impl FromRequest for String { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -205,25 +204,40 @@ impl FromRequest for String { Err(e) => return Either::Right(err(e.into())), }; let limit = cfg.limit; - let fut = HttpMessageBody::new(req, payload).limit(limit); + let body_fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left( - async move { - let body = fut.await?; + Either::Left(StringExtractFut { body_fut, encoding }) + } +} - if encoding == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding - .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) - } - } - .boxed_local(), - ) +pub struct StringExtractFut { + body_fut: HttpMessageBody, + encoding: &'static Encoding, +} + +impl<'a> Future for StringExtractFut { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let encoding = self.encoding; + + Pin::new(&mut self.body_fut).poll(cx).map(|out| { + let body = out?; + bytes_to_string(body, encoding) + }) + } +} + +fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result { + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) } }