From ddaf8c3e4379d915ff9ba05bcf2d55227c21ab8e Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Wed, 5 May 2021 18:36:02 +0100 Subject: [PATCH] add associated error type to MessageBody (#2183) --- actix-http/CHANGES.md | 4 + actix-http/src/body/body.rs | 75 +++++++++++++++++-- actix-http/src/body/body_stream.rs | 4 +- actix-http/src/body/message_body.rs | 105 +++++++++++++++++++++++---- actix-http/src/body/mod.rs | 10 ++- actix-http/src/body/response_body.rs | 25 +++++-- actix-http/src/body/sized_stream.rs | 4 +- actix-http/src/builder.rs | 8 +- actix-http/src/client/connection.rs | 5 +- actix-http/src/client/h1proto.rs | 9 ++- actix-http/src/client/h2proto.rs | 30 +++++--- actix-http/src/encoding/encoder.rs | 85 +++++++++++++++++++--- actix-http/src/error.rs | 6 +- actix-http/src/h1/dispatcher.rs | 32 ++++++++ actix-http/src/h1/service.rs | 22 ++++++ actix-http/src/h1/utils.rs | 2 + actix-http/src/h2/dispatcher.rs | 7 ++ actix-http/src/h2/service.rs | 13 ++++ actix-http/src/response.rs | 6 +- actix-http/src/service.rs | 31 +++++++- actix-test/src/lib.rs | 2 + src/middleware/compat.rs | 6 +- src/middleware/logger.rs | 13 +++- src/response/response.rs | 6 +- src/server.rs | 1 + src/service.rs | 6 +- src/test.rs | 4 + 27 files changed, 447 insertions(+), 74 deletions(-) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index cc6330288..f398b1c92 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -2,11 +2,13 @@ ## Unreleased - 2021-xx-xx ### Added +* `BoxAnyBody`: a boxed message body with boxed errors. [#2183] * Re-export `http` crate's `Error` type as `error::HttpError`. [#2171] * Re-export `StatusCode`, `Method`, `Version` and `Uri` at the crate root. [#2171] * Re-export `ContentEncoding` and `ConnectionType` at the crate root. [#2171] ### Changed +* The `MessageBody` trait now has an associated `Error` type. [#2183] * `header` mod is now public. [#2171] * `uri` mod is now public. [#2171] * Update `language-tags` to `0.3`. @@ -14,8 +16,10 @@ ### Removed * Stop re-exporting `http` crate's `HeaderMap` types in addition to ours. [#2171] +* Down-casting for `MessageBody` types. [#2183] [#2171]: https://github.com/actix/actix-web/pull/2171 +[#2183]: https://github.com/actix/actix-web/pull/2183 [#2196]: https://github.com/actix/actix-web/pull/2196 diff --git a/actix-http/src/body/body.rs b/actix-http/src/body/body.rs index 4fe18338a..4c95bd31a 100644 --- a/actix-http/src/body/body.rs +++ b/actix-http/src/body/body.rs @@ -1,16 +1,17 @@ use std::{ borrow::Cow, + error::Error as StdError, fmt, mem, pin::Pin, task::{Context, Poll}, }; use bytes::{Bytes, BytesMut}; -use futures_core::Stream; +use futures_core::{ready, Stream}; use crate::error::Error; -use super::{BodySize, BodyStream, MessageBody, SizedStream}; +use super::{BodySize, BodyStream, MessageBody, MessageBodyMapErr, SizedStream}; /// Represents various types of HTTP message body. // #[deprecated(since = "4.0.0", note = "Use body types directly.")] @@ -25,7 +26,7 @@ pub enum Body { Bytes(Bytes), /// Generic message body. - Message(Pin>), + Message(BoxAnyBody), } impl Body { @@ -35,12 +36,18 @@ impl Body { } /// Create body from generic message body. - pub fn from_message(body: B) -> Body { - Body::Message(Box::pin(body)) + pub fn from_message(body: B) -> Body + where + B: MessageBody + 'static, + B::Error: Into>, + { + Self::Message(BoxAnyBody::from_body(body)) } } impl MessageBody for Body { + type Error = Error; + fn size(&self) -> BodySize { match self { Body::None => BodySize::None, @@ -53,7 +60,7 @@ impl MessageBody for Body { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match self.get_mut() { Body::None => Poll::Ready(None), Body::Empty => Poll::Ready(None), @@ -65,7 +72,13 @@ impl MessageBody for Body { Poll::Ready(Some(Ok(mem::take(bin)))) } } - Body::Message(body) => body.as_mut().poll_next(cx), + + // TODO: MSRV 1.51: poll_map_err + Body::Message(body) => match ready!(body.as_pin_mut().poll_next(cx)) { + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + }, } } } @@ -166,3 +179,51 @@ where Body::from_message(s) } } + +/// A boxed message body with boxed errors. +pub struct BoxAnyBody(Pin>>>); + +impl BoxAnyBody { + /// Boxes a `MessageBody` and any errors it generates. + pub fn from_body(body: B) -> Self + where + B: MessageBody + 'static, + B::Error: Into>, + { + let body = MessageBodyMapErr::new(body, Into::into); + Self(Box::pin(body)) + } + + /// Returns a mutable pinned reference to the inner message body type. + pub fn as_pin_mut( + &mut self, + ) -> Pin<&mut (dyn MessageBody>)> { + self.0.as_mut() + } +} + +impl fmt::Debug for BoxAnyBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("BoxAnyBody(dyn MessageBody)") + } +} + +impl MessageBody for BoxAnyBody { + type Error = Error; + + fn size(&self) -> BodySize { + self.0.size() + } + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + // TODO: MSRV 1.51: poll_map_err + match ready!(self.0.as_mut().poll_next(cx)) { + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + } + } +} diff --git a/actix-http/src/body/body_stream.rs b/actix-http/src/body/body_stream.rs index b81aeb4c1..ebe872022 100644 --- a/actix-http/src/body/body_stream.rs +++ b/actix-http/src/body/body_stream.rs @@ -36,6 +36,8 @@ where S: Stream>, E: Into, { + type Error = Error; + fn size(&self) -> BodySize { BodySize::Stream } @@ -48,7 +50,7 @@ where fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { loop { let stream = self.as_mut().project().stream; diff --git a/actix-http/src/body/message_body.rs b/actix-http/src/body/message_body.rs index 894a5fa98..2d2642ba7 100644 --- a/actix-http/src/body/message_body.rs +++ b/actix-http/src/body/message_body.rs @@ -1,12 +1,15 @@ //! [`MessageBody`] trait and foreign implementations. use std::{ + convert::Infallible, mem, pin::Pin, task::{Context, Poll}, }; use bytes::{Bytes, BytesMut}; +use futures_core::ready; +use pin_project_lite::pin_project; use crate::error::Error; @@ -14,6 +17,8 @@ use super::BodySize; /// An interface for response bodies. pub trait MessageBody { + type Error; + /// Body size hint. fn size(&self) -> BodySize; @@ -21,14 +26,12 @@ pub trait MessageBody { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>; - - downcast_get_type_id!(); + ) -> Poll>>; } -downcast!(MessageBody); - impl MessageBody for () { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Empty } @@ -36,12 +39,18 @@ impl MessageBody for () { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { Poll::Ready(None) } } -impl MessageBody for Box { +impl MessageBody for Box +where + B: MessageBody + Unpin, + B::Error: Into, +{ + type Error = B::Error; + fn size(&self) -> BodySize { self.as_ref().size() } @@ -49,12 +58,18 @@ impl MessageBody for Box { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { Pin::new(self.get_mut().as_mut()).poll_next(cx) } } -impl MessageBody for Pin> { +impl MessageBody for Pin> +where + B: MessageBody, + B::Error: Into, +{ + type Error = B::Error; + fn size(&self) -> BodySize { self.as_ref().size() } @@ -62,12 +77,14 @@ impl MessageBody for Pin> { fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { self.as_mut().poll_next(cx) } } impl MessageBody for Bytes { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Sized(self.len() as u64) } @@ -75,7 +92,7 @@ impl MessageBody for Bytes { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { if self.is_empty() { Poll::Ready(None) } else { @@ -85,6 +102,8 @@ impl MessageBody for Bytes { } impl MessageBody for BytesMut { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Sized(self.len() as u64) } @@ -92,7 +111,7 @@ impl MessageBody for BytesMut { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { if self.is_empty() { Poll::Ready(None) } else { @@ -102,6 +121,8 @@ impl MessageBody for BytesMut { } impl MessageBody for &'static str { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Sized(self.len() as u64) } @@ -109,7 +130,7 @@ impl MessageBody for &'static str { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { if self.is_empty() { Poll::Ready(None) } else { @@ -121,6 +142,8 @@ impl MessageBody for &'static str { } impl MessageBody for Vec { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Sized(self.len() as u64) } @@ -128,7 +151,7 @@ impl MessageBody for Vec { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { if self.is_empty() { Poll::Ready(None) } else { @@ -138,6 +161,8 @@ impl MessageBody for Vec { } impl MessageBody for String { + type Error = Infallible; + fn size(&self) -> BodySize { BodySize::Sized(self.len() as u64) } @@ -145,7 +170,7 @@ impl MessageBody for String { fn poll_next( self: Pin<&mut Self>, _: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { if self.is_empty() { Poll::Ready(None) } else { @@ -155,3 +180,53 @@ impl MessageBody for String { } } } + +pin_project! { + pub(crate) struct MessageBodyMapErr { + #[pin] + body: B, + mapper: Option, + } +} + +impl MessageBodyMapErr +where + B: MessageBody, + F: FnOnce(B::Error) -> E, +{ + pub(crate) fn new(body: B, mapper: F) -> Self { + Self { + body, + mapper: Some(mapper), + } + } +} + +impl MessageBody for MessageBodyMapErr +where + B: MessageBody, + F: FnOnce(B::Error) -> E, +{ + type Error = E; + + fn size(&self) -> BodySize { + self.body.size() + } + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.as_mut().project(); + + match ready!(this.body.poll_next(cx)) { + Some(Err(err)) => { + let f = self.as_mut().project().mapper.take().unwrap(); + let mapped_err = (f)(err); + Poll::Ready(Some(Err(mapped_err))) + } + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + } + } +} diff --git a/actix-http/src/body/mod.rs b/actix-http/src/body/mod.rs index f26d6a8cf..d21cca60b 100644 --- a/actix-http/src/body/mod.rs +++ b/actix-http/src/body/mod.rs @@ -15,9 +15,10 @@ mod response_body; mod size; mod sized_stream; -pub use self::body::Body; +pub use self::body::{Body, BoxAnyBody}; pub use self::body_stream::BodyStream; pub use self::message_body::MessageBody; +pub(crate) use self::message_body::MessageBodyMapErr; pub use self::response_body::ResponseBody; pub use self::size::BodySize; pub use self::sized_stream::SizedStream; @@ -41,7 +42,7 @@ pub use self::sized_stream::SizedStream; /// assert_eq!(bytes, b"123"[..]); /// # } /// ``` -pub async fn to_bytes(body: impl MessageBody) -> Result { +pub async fn to_bytes(body: B) -> Result { let cap = match body.size() { BodySize::None | BodySize::Empty | BodySize::Sized(0) => return Ok(Bytes::new()), BodySize::Sized(size) => size as usize, @@ -237,10 +238,13 @@ mod tests { ); } + // down-casting used to be done with a method on MessageBody trait + // test is kept to demonstrate equivalence of Any trait #[actix_rt::test] async fn test_body_casting() { let mut body = String::from("hello cast"); - let resp_body: &mut dyn MessageBody = &mut body; + // let mut resp_body: &mut dyn MessageBody = &mut body; + let resp_body: &mut dyn std::any::Any = &mut body; let body = resp_body.downcast_ref::().unwrap(); assert_eq!(body, "hello cast"); let body = &mut resp_body.downcast_mut::().unwrap(); diff --git a/actix-http/src/body/response_body.rs b/actix-http/src/body/response_body.rs index b27112475..855c742f2 100644 --- a/actix-http/src/body/response_body.rs +++ b/actix-http/src/body/response_body.rs @@ -5,7 +5,7 @@ use std::{ }; use bytes::Bytes; -use futures_core::Stream; +use futures_core::{ready, Stream}; use pin_project::pin_project; use crate::error::Error; @@ -43,7 +43,13 @@ impl ResponseBody { } } -impl MessageBody for ResponseBody { +impl MessageBody for ResponseBody +where + B: MessageBody, + B::Error: Into, +{ + type Error = Error; + fn size(&self) -> BodySize { match self { ResponseBody::Body(ref body) => body.size(), @@ -54,12 +60,16 @@ impl MessageBody for ResponseBody { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { Stream::poll_next(self, cx) } } -impl Stream for ResponseBody { +impl Stream for ResponseBody +where + B: MessageBody, + B::Error: Into, +{ type Item = Result; fn poll_next( @@ -67,7 +77,12 @@ impl Stream for ResponseBody { cx: &mut Context<'_>, ) -> Poll> { match self.project() { - ResponseBodyProj::Body(body) => body.poll_next(cx), + // TODO: MSRV 1.51: poll_map_err + ResponseBodyProj::Body(body) => match ready!(body.poll_next(cx)) { + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + }, ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), } } diff --git a/actix-http/src/body/sized_stream.rs b/actix-http/src/body/sized_stream.rs index f0332fc8f..4af132389 100644 --- a/actix-http/src/body/sized_stream.rs +++ b/actix-http/src/body/sized_stream.rs @@ -36,6 +36,8 @@ impl MessageBody for SizedStream where S: Stream>, { + type Error = Error; + fn size(&self) -> BodySize { BodySize::Sized(self.size as u64) } @@ -48,7 +50,7 @@ where fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { loop { let stream = self.as_mut().project().stream; diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 623bfdda2..660cd9817 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -202,11 +202,13 @@ where /// Finish service configuration and create a HTTP service for HTTP/2 protocol. pub fn h2(self, service: F) -> H2Service where - B: MessageBody + 'static, F: IntoServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, + + B: MessageBody + 'static, + B::Error: Into, { let cfg = ServiceConfig::new( self.keep_alive, @@ -223,11 +225,13 @@ where /// Finish service configuration and create `HttpService` instance. pub fn finish(self, service: F) -> HttpService where - B: MessageBody + 'static, F: IntoServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, + + B: MessageBody + 'static, + B::Error: Into, { let cfg = ServiceConfig::new( self.keep_alive, diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index 0e3e97f3f..a30f651ca 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -12,10 +12,10 @@ use bytes::Bytes; use futures_core::future::LocalBoxFuture; use h2::client::SendRequest; -use crate::body::MessageBody; use crate::h1::ClientCodec; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; +use crate::{body::MessageBody, Error}; use super::error::SendRequestError; use super::pool::Acquired; @@ -256,8 +256,9 @@ where body: RB, ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> where - RB: MessageBody + 'static, H: Into + 'static, + RB: MessageBody + 'static, + RB::Error: Into, { Box::pin(async move { match self { diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index fa4469d35..65a30748c 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -11,7 +11,6 @@ use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; use futures_util::SinkExt as _; -use crate::error::PayloadError; use crate::h1; use crate::http::{ header::{HeaderMap, IntoHeaderValue, EXPECT, HOST}, @@ -19,6 +18,7 @@ use crate::http::{ }; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; +use crate::{error::PayloadError, Error}; use super::connection::{ConnectionIo, H1Connection}; use super::error::{ConnectError, SendRequestError}; @@ -32,6 +32,7 @@ pub(crate) async fn send_request( where Io: ConnectionIo, B: MessageBody, + B::Error: Into, { // set request host header if !head.as_ref().headers.contains_key(HOST) @@ -154,6 +155,7 @@ pub(crate) async fn send_body( where Io: ConnectionIo, B: MessageBody, + B::Error: Into, { actix_rt::pin!(body); @@ -161,9 +163,10 @@ where while !eof { while !eof && !framed.as_ref().is_write_buf_full() { match poll_fn(|cx| body.as_mut().poll_next(cx)).await { - Some(result) => { - framed.as_mut().write(h1::Message::Chunk(Some(result?)))?; + Some(Ok(chunk)) => { + framed.as_mut().write(h1::Message::Chunk(Some(chunk)))?; } + Some(Err(err)) => return Err(err.into().into()), None => { eof = true; framed.as_mut().write(h1::Message::Chunk(None))?; diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 8cb2e2522..cf423ef12 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -9,14 +9,19 @@ use h2::{ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, Method, Version}; -use crate::body::{BodySize, MessageBody}; -use crate::header::HeaderMap; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::Payload; +use crate::{ + body::{BodySize, MessageBody}, + header::HeaderMap, + message::{RequestHeadType, ResponseHead}, + payload::Payload, + Error, +}; -use super::config::ConnectorConfig; -use super::connection::{ConnectionIo, H2Connection}; -use super::error::SendRequestError; +use super::{ + config::ConnectorConfig, + connection::{ConnectionIo, H2Connection}, + error::SendRequestError, +}; pub(crate) async fn send_request( mut io: H2Connection, @@ -26,6 +31,7 @@ pub(crate) async fn send_request( where Io: ConnectionIo, B: MessageBody, + B::Error: Into, { trace!("Sending client request: {:?} {:?}", head, body.size()); @@ -125,10 +131,14 @@ where Ok((head, payload)) } -async fn send_body( +async fn send_body( body: B, mut send: SendStream, -) -> Result<(), SendRequestError> { +) -> Result<(), SendRequestError> +where + B: MessageBody, + B::Error: Into, +{ let mut buf = None; actix_rt::pin!(body); loop { @@ -138,7 +148,7 @@ async fn send_body( send.reserve_capacity(b.len()); buf = Some(b); } - Some(Err(e)) => return Err(e.into()), + Some(Err(e)) => return Err(e.into().into()), None => { if let Err(e) = send.send_data(Bytes::new(), true) { return Err(e.into()); diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index add6ee980..b8bc8b68d 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,6 +1,7 @@ //! Stream encoders. use std::{ + error::Error as StdError, future::Future, io::{self, Write as _}, pin::Pin, @@ -10,12 +11,13 @@ use std::{ use actix_rt::task::{spawn_blocking, JoinHandle}; use brotli2::write::BrotliEncoder; use bytes::Bytes; +use derive_more::Display; use flate2::write::{GzEncoder, ZlibEncoder}; use futures_core::ready; use pin_project::pin_project; use crate::{ - body::{Body, BodySize, MessageBody, ResponseBody}, + body::{Body, BodySize, BoxAnyBody, MessageBody, ResponseBody}, http::{ header::{ContentEncoding, CONTENT_ENCODING}, HeaderValue, StatusCode, @@ -92,10 +94,16 @@ impl Encoder { enum EncoderBody { Bytes(Bytes), Stream(#[pin] B), - BoxedStream(Pin>), + BoxedStream(BoxAnyBody), } -impl MessageBody for EncoderBody { +impl MessageBody for EncoderBody +where + B: MessageBody, + B::Error: Into, +{ + type Error = EncoderError; + fn size(&self) -> BodySize { match self { EncoderBody::Bytes(ref b) => b.size(), @@ -107,7 +115,7 @@ impl MessageBody for EncoderBody { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match self.project() { EncoderBodyProj::Bytes(b) => { if b.is_empty() { @@ -116,13 +124,32 @@ impl MessageBody for EncoderBody { Poll::Ready(Some(Ok(std::mem::take(b)))) } } - EncoderBodyProj::Stream(b) => b.poll_next(cx), - EncoderBodyProj::BoxedStream(ref mut b) => b.as_mut().poll_next(cx), + // TODO: MSRV 1.51: poll_map_err + EncoderBodyProj::Stream(b) => match ready!(b.poll_next(cx)) { + Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Body(err)))), + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + }, + EncoderBodyProj::BoxedStream(ref mut b) => { + match ready!(b.as_pin_mut().poll_next(cx)) { + Some(Err(err)) => { + Poll::Ready(Some(Err(EncoderError::Boxed(err.into())))) + } + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + } + } } } } -impl MessageBody for Encoder { +impl MessageBody for Encoder +where + B: MessageBody, + B::Error: Into, +{ + type Error = EncoderError; + fn size(&self) -> BodySize { if self.encoder.is_none() { self.body.size() @@ -134,7 +161,7 @@ impl MessageBody for Encoder { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let mut this = self.project(); loop { if *this.eof { @@ -142,8 +169,9 @@ impl MessageBody for Encoder { } if let Some(ref mut fut) = this.fut { - let mut encoder = - ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + let mut encoder = ready!(Pin::new(fut).poll(cx)) + .map_err(|_| EncoderError::Blocking(BlockingError))? + .map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -162,7 +190,7 @@ impl MessageBody for Encoder { Some(Ok(chunk)) => { if let Some(mut encoder) = this.encoder.take() { if chunk.len() < MAX_CHUNK_SIZE_ENCODE_IN_PLACE { - encoder.write(&chunk)?; + encoder.write(&chunk).map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -182,7 +210,7 @@ impl MessageBody for Encoder { None => { if let Some(encoder) = this.encoder.take() { - let chunk = encoder.finish()?; + let chunk = encoder.finish().map_err(EncoderError::Io)?; if chunk.is_empty() { return Poll::Ready(None); } else { @@ -281,3 +309,36 @@ impl ContentEncoder { } } } + +#[derive(Debug, Display)] +#[non_exhaustive] +pub enum EncoderError { + #[display(fmt = "body")] + Body(E), + + #[display(fmt = "boxed")] + Boxed(Error), + + #[display(fmt = "blocking")] + Blocking(BlockingError), + + #[display(fmt = "io")] + Io(io::Error), +} + +impl StdError for EncoderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + None + } +} + +impl> From> for Error { + fn from(err: EncoderError) -> Self { + match err { + EncoderError::Body(err) => err.into(), + EncoderError::Boxed(err) => err, + EncoderError::Blocking(err) => err.into(), + EncoderError::Io(err) => err.into(), + } + } +} diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 39ffa29e7..c92f9076d 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -2,6 +2,7 @@ use std::{ cell::RefCell, + error::Error as StdError, fmt, io::{self, Write as _}, str::Utf8Error, @@ -105,8 +106,7 @@ impl From<()> for Error { impl From for Error { fn from(_: std::convert::Infallible) -> Self { - // `std::convert::Infallible` indicates an error - // that will never happen + // hint that an error that will never happen unreachable!() } } @@ -145,6 +145,8 @@ impl From for Error { #[display(fmt = "Unknown Error")] struct UnitError; +impl ResponseError for Box {} + /// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`UnitError`]. impl ResponseError for UnitError {} diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 3b272f0fb..7ab89ba87 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -51,9 +51,13 @@ pub struct Dispatcher where S: Service, S::Error: Into, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -69,9 +73,13 @@ enum DispatcherState where S: Service, S::Error: Into, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -84,9 +92,13 @@ struct InnerDispatcher where S: Service, S::Error: Into, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -122,7 +134,9 @@ enum State where S: Service, X: Service, + B: MessageBody, + B::Error: Into, { None, ExpectCall(#[pin] X::Future), @@ -133,8 +147,11 @@ where impl State where S: Service, + X: Service, + B: MessageBody, + B::Error: Into, { fn is_empty(&self) -> bool { matches!(self, State::None) @@ -150,12 +167,17 @@ enum PollResponse { impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -206,12 +228,17 @@ where impl InnerDispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -817,12 +844,17 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 916643a18..1ab85cbf3 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -64,11 +64,15 @@ where S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: ServiceFactory, X::Future: 'static, X::Error: Into, X::InitError: fmt::Debug, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Future: 'static, U::Error: fmt::Display + Into, @@ -109,11 +113,15 @@ mod openssl { S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: ServiceFactory, X::Future: 'static, X::Error: Into, X::InitError: fmt::Debug, + U: ServiceFactory< (Request, Framed, Codec>), Config = (), @@ -165,11 +173,15 @@ mod rustls { S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: ServiceFactory, X::Future: 'static, X::Error: Into, X::InitError: fmt::Debug, + U: ServiceFactory< (Request, Framed, Codec>), Config = (), @@ -253,16 +265,21 @@ impl ServiceFactory<(T, Option)> for H1Service where T: AsyncRead + AsyncWrite + Unpin + 'static, + S: ServiceFactory, S::Future: 'static, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, + B: MessageBody, + B::Error: Into, + X: ServiceFactory, X::Future: 'static, X::Error: Into, X::InitError: fmt::Debug, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Future: 'static, U::Error: fmt::Display + Into, @@ -319,12 +336,17 @@ impl Service<(T, Option)> for HttpServiceHandler where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into, S::Response: Into>, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display + Into, { diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index 9e9c57137..73f01e913 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -22,6 +22,7 @@ pub struct SendResponse { impl SendResponse where B: MessageBody, + B::Error: Into, { pub fn new(framed: Framed, response: Response) -> Self { let (res, body) = response.into_parts(); @@ -38,6 +39,7 @@ impl Future for SendResponse where T: AsyncRead + AsyncWrite + Unpin, B: MessageBody + Unpin, + B::Error: Into, { type Output = Result, Error>; diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 87dd66fe7..07636470b 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -69,11 +69,14 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, + B: MessageBody + 'static, + B::Error: Into, { type Output = Result<(), DispatchError>; @@ -140,7 +143,9 @@ where F: Future>, E: Into, I: Into>, + B: MessageBody, + B::Error: Into, { fn prepare_response( &self, @@ -216,7 +221,9 @@ where F: Future>, E: Into, I: Into>, + B: MessageBody, + B::Error: Into, { type Output = (); diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 1a0b8c7f5..a75abef7d 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -40,7 +40,9 @@ where S::Error: Into + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + B::Error: Into, { /// Create new `H2Service` instance with config. pub(crate) fn with_config>( @@ -69,7 +71,9 @@ where S::Error: Into + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + B::Error: Into, { /// Create plain TCP based service pub fn tcp( @@ -106,7 +110,9 @@ mod openssl { S::Error: Into + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + B::Error: Into, { /// Create OpenSSL based service pub fn openssl( @@ -150,7 +156,9 @@ mod rustls { S::Error: Into + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + B::Error: Into, { /// Create Rustls based service pub fn rustls( @@ -185,12 +193,15 @@ mod rustls { impl ServiceFactory<(T, Option)> for H2Service where T: AsyncRead + AsyncWrite + Unpin + 'static, + S: ServiceFactory, S::Future: 'static, S::Error: Into + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + B::Error: Into, { type Response = (); type Error = DispatchError; @@ -252,6 +263,7 @@ where S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, + B::Error: Into, { type Response = (); type Error = DispatchError; @@ -316,6 +328,7 @@ where S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, + B::Error: Into, { type Output = Result<(), DispatchError>; diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index e11ceb18f..da5c7e000 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -242,7 +242,11 @@ impl Response { } } -impl fmt::Debug for Response { +impl fmt::Debug for Response +where + B: MessageBody, + B::Error: Into, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( f, diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index ff4b49f1d..d25a67a19 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -59,6 +59,7 @@ where S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, { /// Create new `HttpService` instance. pub fn new>(service: F) -> Self { @@ -157,6 +158,7 @@ where >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, X: ServiceFactory, X::Future: 'static, @@ -208,6 +210,7 @@ mod openssl { >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, X: ServiceFactory, X::Future: 'static, @@ -275,6 +278,7 @@ mod rustls { >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, X: ServiceFactory, X::Future: 'static, @@ -339,6 +343,7 @@ where >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, X: ServiceFactory, X::Future: 'static, @@ -465,13 +470,18 @@ impl Service<(T, Protocol, Option)> for HttpServiceHandler where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, + B: MessageBody + 'static, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display + Into, { @@ -522,13 +532,18 @@ where #[pin_project(project = StateProj)] enum State where + T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Future: 'static, S::Error: Into, - T: AsyncRead + AsyncWrite + Unpin, + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -549,13 +564,18 @@ where pub struct HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, - B: MessageBody + 'static, + + B: MessageBody, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -566,13 +586,18 @@ where impl Future for HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, - B: MessageBody, + + B: MessageBody + 'static, + B::Error: Into, + X: Service, X::Error: Into, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { diff --git a/actix-test/src/lib.rs b/actix-test/src/lib.rs index 8fab33289..5d85c2687 100644 --- a/actix-test/src/lib.rs +++ b/actix-test/src/lib.rs @@ -86,6 +86,7 @@ where S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, { start_with(TestServerConfig::default(), factory) } @@ -125,6 +126,7 @@ where S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, + B::Error: Into, { let (tx, rx) = mpsc::channel(); diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs index 0e3a4f2b7..3a85591da 100644 --- a/src/middleware/compat.rs +++ b/src/middleware/compat.rs @@ -113,7 +113,11 @@ pub trait MapServiceResponseBody { fn map_body(self) -> ServiceResponse; } -impl MapServiceResponseBody for ServiceResponse { +impl MapServiceResponseBody for ServiceResponse +where + B: MessageBody + Unpin + 'static, + B::Error: Into, +{ fn map_body(self) -> ServiceResponse { self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 40ed9258f..8a60d6c70 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -22,10 +22,9 @@ use time::OffsetDateTime; use crate::{ dev::{BodySize, MessageBody, ResponseBody}, - error::{Error, Result}, http::{HeaderName, StatusCode}, service::{ServiceRequest, ServiceResponse}, - HttpResponse, + Error, HttpResponse, Result, }; /// Middleware for logging request and response summaries to the terminal. @@ -327,7 +326,13 @@ impl PinnedDrop for StreamLog { } } -impl MessageBody for StreamLog { +impl MessageBody for StreamLog +where + B: MessageBody, + B::Error: Into, +{ + type Error = Error; + fn size(&self) -> BodySize { self.body.size() } @@ -335,7 +340,7 @@ impl MessageBody for StreamLog { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let this = self.project(); match this.body.poll_next(cx) { Poll::Ready(Some(Ok(chunk))) => { diff --git a/src/response/response.rs b/src/response/response.rs index 31868fe0b..6e09a0136 100644 --- a/src/response/response.rs +++ b/src/response/response.rs @@ -243,7 +243,11 @@ impl HttpResponse { } } -impl fmt::Debug for HttpResponse { +impl fmt::Debug for HttpResponse +where + B: MessageBody, + B::Error: Into, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HttpResponse") .field("error", &self.error) diff --git a/src/server.rs b/src/server.rs index 6577f4d1f..6e11c642f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -81,6 +81,7 @@ where S::Service: 'static, // S::Service: 'static, B: MessageBody + 'static, + B::Error: Into, { /// Create new HTTP server with application factory pub fn new(factory: F) -> Self { diff --git a/src/service.rs b/src/service.rs index f6d1f9ebf..0c03f84ad 100644 --- a/src/service.rs +++ b/src/service.rs @@ -443,7 +443,11 @@ impl From> for Response { } } -impl fmt::Debug for ServiceResponse { +impl fmt::Debug for ServiceResponse +where + B: MessageBody, + B::Error: Into, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( f, diff --git a/src/test.rs b/src/test.rs index c2e456e58..9fe5e9b5d 100644 --- a/src/test.rs +++ b/src/test.rs @@ -151,6 +151,7 @@ pub async fn read_response(app: &S, req: Request) -> Bytes where S: Service, Error = Error>, B: MessageBody + Unpin, + B::Error: Into, { let mut resp = app .call(req) @@ -196,6 +197,7 @@ where pub async fn read_body(mut res: ServiceResponse) -> Bytes where B: MessageBody + Unpin, + B::Error: Into, { let mut body = res.take_body(); let mut bytes = BytesMut::new(); @@ -245,6 +247,7 @@ where pub async fn read_body_json(res: ServiceResponse) -> T where B: MessageBody + Unpin, + B::Error: Into, T: DeserializeOwned, { let body = read_body(res).await; @@ -306,6 +309,7 @@ pub async fn read_response_json(app: &S, req: Request) -> T where S: Service, Error = Error>, B: MessageBody + Unpin, + B::Error: Into, T: DeserializeOwned, { let body = read_response(app, req).await;