From dd4a372613339f6668dc248d2dbc414c3a6ccfad Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Tue, 14 Dec 2021 21:17:50 +0000 Subject: [PATCH] allow error handler middleware to return different body type (#2515) --- CHANGES.md | 3 + actix-router/src/url.rs | 53 ++++++++------ scripts/ci-test | 32 ++++++--- src/middleware/compat.rs | 22 +++--- src/middleware/condition.rs | 13 ++-- src/middleware/err_handlers.rs | 128 +++++++++++++++++++++------------ src/middleware/mod.rs | 4 +- 7 files changed, 162 insertions(+), 93 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b8d3ce8de..0c27aaa1c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,12 +7,15 @@ ### Changed * Align `DefaultHeader` method terminology, deprecating previous methods. [#2510] +* Response service types in `ErrorHandlers` middleware now use `ServiceResponse>` to allow changing the body type. [#2515] +* Both variants in `ErrorHandlerResponse` now use `ServiceResponse>`. [#2515] ### Removed * Top-level `EitherExtractError` export. [#2510] * Conversion implementations for `either` crate. [#2516] [#2510]: https://github.com/actix/actix-web/pull/2510 +[#2515]: https://github.com/actix/actix-web/pull/2515 [#2516]: https://github.com/actix/actix-web/pull/2516 diff --git a/actix-router/src/url.rs b/actix-router/src/url.rs index e08a7171a..10193dde8 100644 --- a/actix-router/src/url.rs +++ b/actix-router/src/url.rs @@ -2,22 +2,28 @@ use crate::ResourcePath; #[allow(dead_code)] const GEN_DELIMS: &[u8] = b":/?#[]@"; + #[allow(dead_code)] const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,"; + #[allow(dead_code)] const SUB_DELIMS: &[u8] = b"!$'()*,+?=;"; + #[allow(dead_code)] const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;"; + #[allow(dead_code)] const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ 1234567890 -._~"; + const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ 1234567890 -._~ !$'()*,"; + const QS: &[u8] = b"+&=;b"; #[inline] @@ -34,19 +40,20 @@ thread_local! { static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+"); } -#[derive(Default, Clone, Debug)] +#[derive(Debug, Clone, Default)] pub struct Url { uri: http::Uri, path: Option, } impl Url { + #[inline] pub fn new(uri: http::Uri) -> Url { let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); - Url { uri, path } } + #[inline] pub fn with_quoter(uri: http::Uri, quoter: &Quoter) -> Url { Url { path: quoter.requote(uri.path().as_bytes()), @@ -54,15 +61,16 @@ impl Url { } } + #[inline] pub fn uri(&self) -> &http::Uri { &self.uri } + #[inline] pub fn path(&self) -> &str { - if let Some(ref s) = self.path { - s - } else { - self.uri.path() + match self.path { + Some(ref path) => path, + _ => self.uri.path(), } } @@ -86,6 +94,7 @@ impl ResourcePath for Url { } } +/// A quoter pub struct Quoter { safe_table: [u8; 16], protected_table: [u8; 16], @@ -93,7 +102,7 @@ pub struct Quoter { impl Quoter { pub fn new(safe: &[u8], protected: &[u8]) -> Quoter { - let mut q = Quoter { + let mut quoter = Quoter { safe_table: [0; 16], protected_table: [0; 16], }; @@ -101,24 +110,24 @@ impl Quoter { // prepare safe table for i in 0..128 { if ALLOWED.contains(&i) { - set_bit(&mut q.safe_table, i); + set_bit(&mut quoter.safe_table, i); } if QS.contains(&i) { - set_bit(&mut q.safe_table, i); + set_bit(&mut quoter.safe_table, i); } } for ch in safe { - set_bit(&mut q.safe_table, *ch) + set_bit(&mut quoter.safe_table, *ch) } // prepare protected table for ch in protected { - set_bit(&mut q.safe_table, *ch); - set_bit(&mut q.protected_table, *ch); + set_bit(&mut quoter.safe_table, *ch); + set_bit(&mut quoter.protected_table, *ch); } - q + quoter } pub fn requote(&self, val: &[u8]) -> Option { @@ -215,7 +224,7 @@ mod tests { } #[test] - fn test_parse_url() { + fn parse_url() { let re = "/user/{id}/test"; let path = match_url(re, "/user/2345/test"); @@ -231,24 +240,24 @@ mod tests { } #[test] - fn test_protected_chars() { + fn protected_chars() { let encoded = percent_encode(PROTECTED); let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); assert_eq!(path.get("id").unwrap(), &encoded); } #[test] - fn test_non_protecteed_ascii() { - let nonprotected_ascii = ('\u{0}'..='\u{7F}') + fn non_protected_ascii() { + let non_protected_ascii = ('\u{0}'..='\u{7F}') .filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8))) .collect::(); - let encoded = percent_encode(nonprotected_ascii.as_bytes()); + let encoded = percent_encode(non_protected_ascii.as_bytes()); let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); - assert_eq!(path.get("id").unwrap(), &nonprotected_ascii); + assert_eq!(path.get("id").unwrap(), &non_protected_ascii); } #[test] - fn test_valid_utf8_multibyte() { + fn valid_utf8_multibyte() { let test = ('\u{FF00}'..='\u{FFFF}').collect::(); let encoded = percent_encode(test.as_bytes()); let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded)); @@ -256,7 +265,7 @@ mod tests { } #[test] - fn test_invalid_utf8() { + fn invalid_utf8() { let invalid_utf8 = percent_encode((0x80..=0xff).collect::>().as_slice()); let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap(); let path = Path::new(Url::new(uri)); @@ -266,7 +275,7 @@ mod tests { } #[test] - fn test_from_hex() { + fn hex_encoding() { let hex = b"0123456789abcdefABCDEF"; for i in 0..256 { diff --git a/scripts/ci-test b/scripts/ci-test index 98e13927d..3ab229665 100755 --- a/scripts/ci-test +++ b/scripts/ci-test @@ -4,15 +4,25 @@ set -x -cargo test --lib --tests -p=actix-router --all-features -cargo test --lib --tests -p=actix-http --all-features -cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls -cargo test --lib --tests -p=actix-web-codegen --all-features -cargo test --lib --tests -p=awc --all-features -cargo test --lib --tests -p=actix-http-test --all-features -cargo test --lib --tests -p=actix-test --all-features -cargo test --lib --tests -p=actix-files -cargo test --lib --tests -p=actix-multipart --all-features -cargo test --lib --tests -p=actix-web-actors --all-features +EXIT=0 -cargo test --workspace --doc +save_exit_code() { + eval $@ + local CMD_EXIT=$? + [ "$CMD_EXIT" = "0" ] || EXIT=$CMD_EXIT +} + +save_exit_code cargo test --lib --tests -p=actix-router --all-features +save_exit_code cargo test --lib --tests -p=actix-http --all-features +save_exit_code cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls +save_exit_code cargo test --lib --tests -p=actix-web-codegen --all-features +save_exit_code cargo test --lib --tests -p=awc --all-features +save_exit_code cargo test --lib --tests -p=actix-http-test --all-features +save_exit_code cargo test --lib --tests -p=actix-test --all-features +save_exit_code cargo test --lib --tests -p=actix-files +save_exit_code cargo test --lib --tests -p=actix-multipart --all-features +save_exit_code cargo test --lib --tests -p=actix-web-actors --all-features + +save_exit_code cargo test --workspace --doc + +exit $EXIT diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs index ed441f7b9..d49c461c4 100644 --- a/src/middleware/compat.rs +++ b/src/middleware/compat.rs @@ -6,12 +6,15 @@ use std::{ task::{Context, Poll}, }; -use actix_http::body::MessageBody; -use actix_service::{Service, Transform}; use futures_core::{future::LocalBoxFuture, ready}; use pin_project_lite::pin_project; -use crate::{error::Error, service::ServiceResponse}; +use crate::{ + body::{BoxBody, MessageBody}, + dev::{Service, Transform}, + error::Error, + service::ServiceResponse, +}; /// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap), /// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition). @@ -52,7 +55,7 @@ where T::Response: MapServiceResponseBody, T::Error: Into, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Transform = CompatMiddleware; type InitError = T::InitError; @@ -77,7 +80,7 @@ where S::Response: MapServiceResponseBody, S::Error: Into, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = CompatMiddlewareFuture; @@ -102,7 +105,7 @@ where T: MapServiceResponseBody, E: Into, { - type Output = Result; + type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = match ready!(self.project().fut.poll(cx)) { @@ -116,14 +119,15 @@ where /// Convert `ServiceResponse`'s `ResponseBody` generic type to `ResponseBody`. pub trait MapServiceResponseBody { - fn map_body(self) -> ServiceResponse; + fn map_body(self) -> ServiceResponse; } impl MapServiceResponseBody for ServiceResponse where - B: MessageBody + Unpin + 'static, + B: MessageBody + 'static, { - fn map_body(self) -> ServiceResponse { + #[inline] + fn map_body(self) -> ServiceResponse { self.map_into_boxed_body() } } diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index a7777a96b..659f88bc9 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -106,7 +106,7 @@ mod tests { header::{HeaderValue, CONTENT_TYPE}, StatusCode, }, - middleware::err_handlers::*, + middleware::{err_handlers::*, Compat}, test::{self, TestRequest}, HttpResponse, }; @@ -116,7 +116,8 @@ mod tests { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Response(res)) + + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) } #[actix_rt::test] @@ -125,7 +126,9 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = Compat::new( + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) @@ -141,7 +144,9 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = Compat::new( + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) diff --git a/src/middleware/err_handlers.rs b/src/middleware/err_handlers.rs index 756da30c3..fedefa6fa 100644 --- a/src/middleware/err_handlers.rs +++ b/src/middleware/err_handlers.rs @@ -13,6 +13,7 @@ use futures_core::{future::LocalBoxFuture, ready}; use pin_project_lite::pin_project; use crate::{ + body::EitherBody, dev::{ServiceRequest, ServiceResponse}, http::StatusCode, Error, Result, @@ -21,10 +22,10 @@ use crate::{ /// Return type for [`ErrorHandlers`] custom handlers. pub enum ErrorHandlerResponse { /// Immediate HTTP response. - Response(ServiceResponse), + Response(ServiceResponse>), /// A future that resolves to an HTTP response. - Future(LocalBoxFuture<'static, Result, Error>>), + Future(LocalBoxFuture<'static, Result>, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; @@ -44,7 +45,8 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result = Rc>>>; impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { - handlers: Rc::new(AHashMap::default()), + handlers: Default::default(), } } } @@ -95,7 +97,7 @@ where S::Future: 'static, B: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; type Transform = ErrorHandlersMiddleware; type InitError = (); @@ -119,7 +121,7 @@ where S::Future: 'static, B: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; type Future = ErrorHandlersFuture; @@ -143,8 +145,8 @@ pin_project! { fut: Fut, handlers: Handlers, }, - HandlerFuture { - fut: LocalBoxFuture<'static, Fut::Output>, + ErrorHandlerFuture { + fut: LocalBoxFuture<'static, Result>, Error>>, }, } } @@ -153,25 +155,29 @@ impl Future for ErrorHandlersFuture where Fut: Future, Error>>, { - type Output = Fut::Output; + type Output = Result>, Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { ErrorHandlersProj::ServiceFuture { fut, handlers } => { let res = ready!(fut.poll(cx))?; + match handlers.get(&res.status()) { Some(handler) => match handler(res)? { ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Future(fut) => { self.as_mut() - .set(ErrorHandlersFuture::HandlerFuture { fut }); + .set(ErrorHandlersFuture::ErrorHandlerFuture { fut }); + self.poll(cx) } }, - None => Poll::Ready(Ok(res)), + + None => Poll::Ready(Ok(res.map_into_left_body())), } } - ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx), + + ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx), } } } @@ -180,32 +186,33 @@ where mod tests { use actix_service::IntoService; use actix_utils::future::ok; + use bytes::Bytes; use futures_util::future::FutureExt as _; use super::*; - use crate::http::{ - header::{HeaderValue, CONTENT_TYPE}, - StatusCode, + use crate::{ + http::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + test::{self, TestRequest}, }; - use crate::test::{self, TestRequest}; - use crate::HttpResponse; - - #[allow(clippy::unnecessary_wraps)] - fn render_500(mut res: ServiceResponse) -> Result> { - res.response_mut() - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Response(res)) - } #[actix_rt::test] - async fn test_handler() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) - }; + async fn add_header_error_handler() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR); let mw = ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) .new_transform(srv.into_service()) .await .unwrap(); @@ -214,24 +221,25 @@ mod tests { assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } - #[allow(clippy::unnecessary_wraps)] - fn render_500_async( - mut res: ServiceResponse, - ) -> Result> { - res.response_mut() - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) - } - #[actix_rt::test] - async fn test_handler_async() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) - }; + async fn add_header_error_handler_async() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + + Ok(ErrorHandlerResponse::Future( + ok(res.map_into_left_body()).boxed_local(), + )) + } + + let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR); let mw = ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) .new_transform(srv.into_service()) .await .unwrap(); @@ -239,4 +247,34 @@ mod tests { let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } + + #[actix_rt::test] + async fn changes_body_type() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler( + res: ServiceResponse, + ) -> Result> { + let (req, res) = res.into_parts(); + let res = res.set_body(Bytes::from("sorry, that's no bueno")); + + let res = ServiceResponse::new(req, res) + .map_into_boxed_body() + .map_into_right_body(); + + Ok(ErrorHandlerResponse::Response(res)) + } + + let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR); + + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(test::read_body(res).await, "sorry, that's no bueno"); + } + + // TODO: test where error is thrown } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 42d285580..0da9b9b2e 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -35,7 +35,7 @@ mod tests { .wrap(Condition::new(true, DefaultHeaders::new())) .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { - Ok(ErrorHandlerResponse::Response(res)) + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) })) .wrap(Logger::default()) .wrap(NormalizePath::new(TrailingSlash::Trim)); @@ -44,7 +44,7 @@ mod tests { .wrap(NormalizePath::new(TrailingSlash::Trim)) .wrap(Logger::default()) .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { - Ok(ErrorHandlerResponse::Response(res)) + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) })) .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) .wrap(Condition::new(true, DefaultHeaders::new()))