1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-27 17:52:56 +01:00

allow error handler middleware to return different body type (#2515)

This commit is contained in:
Rob Ede 2021-12-14 21:17:50 +00:00 committed by GitHub
parent 05255c7f7c
commit dd4a372613
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 162 additions and 93 deletions

View File

@ -7,12 +7,15 @@
### Changed ### Changed
* Align `DefaultHeader` method terminology, deprecating previous methods. [#2510] * Align `DefaultHeader` method terminology, deprecating previous methods. [#2510]
* Response service types in `ErrorHandlers` middleware now use `ServiceResponse<EitherBody<B>>` to allow changing the body type. [#2515]
* Both variants in `ErrorHandlerResponse` now use `ServiceResponse<EitherBody<B>>`. [#2515]
### Removed ### Removed
* Top-level `EitherExtractError` export. [#2510] * Top-level `EitherExtractError` export. [#2510]
* Conversion implementations for `either` crate. [#2516] * Conversion implementations for `either` crate. [#2516]
[#2510]: https://github.com/actix/actix-web/pull/2510 [#2510]: https://github.com/actix/actix-web/pull/2510
[#2515]: https://github.com/actix/actix-web/pull/2515
[#2516]: https://github.com/actix/actix-web/pull/2516 [#2516]: https://github.com/actix/actix-web/pull/2516

View File

@ -2,22 +2,28 @@ use crate::ResourcePath;
#[allow(dead_code)] #[allow(dead_code)]
const GEN_DELIMS: &[u8] = b":/?#[]@"; const GEN_DELIMS: &[u8] = b":/?#[]@";
#[allow(dead_code)] #[allow(dead_code)]
const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,"; const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,";
#[allow(dead_code)] #[allow(dead_code)]
const SUB_DELIMS: &[u8] = b"!$'()*,+?=;"; const SUB_DELIMS: &[u8] = b"!$'()*,+?=;";
#[allow(dead_code)] #[allow(dead_code)]
const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;"; const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;";
#[allow(dead_code)] #[allow(dead_code)]
const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ
1234567890 1234567890
-._~"; -._~";
const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ
1234567890 1234567890
-._~ -._~
!$'()*,"; !$'()*,";
const QS: &[u8] = b"+&=;b"; const QS: &[u8] = b"+&=;b";
#[inline] #[inline]
@ -34,19 +40,20 @@ thread_local! {
static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+"); static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+");
} }
#[derive(Default, Clone, Debug)] #[derive(Debug, Clone, Default)]
pub struct Url { pub struct Url {
uri: http::Uri, uri: http::Uri,
path: Option<String>, path: Option<String>,
} }
impl Url { impl Url {
#[inline]
pub fn new(uri: http::Uri) -> Url { pub fn new(uri: http::Uri) -> Url {
let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes()));
Url { uri, path } Url { uri, path }
} }
#[inline]
pub fn with_quoter(uri: http::Uri, quoter: &Quoter) -> Url { pub fn with_quoter(uri: http::Uri, quoter: &Quoter) -> Url {
Url { Url {
path: quoter.requote(uri.path().as_bytes()), path: quoter.requote(uri.path().as_bytes()),
@ -54,15 +61,16 @@ impl Url {
} }
} }
#[inline]
pub fn uri(&self) -> &http::Uri { pub fn uri(&self) -> &http::Uri {
&self.uri &self.uri
} }
#[inline]
pub fn path(&self) -> &str { pub fn path(&self) -> &str {
if let Some(ref s) = self.path { match self.path {
s Some(ref path) => path,
} else { _ => self.uri.path(),
self.uri.path()
} }
} }
@ -86,6 +94,7 @@ impl ResourcePath for Url {
} }
} }
/// A quoter
pub struct Quoter { pub struct Quoter {
safe_table: [u8; 16], safe_table: [u8; 16],
protected_table: [u8; 16], protected_table: [u8; 16],
@ -93,7 +102,7 @@ pub struct Quoter {
impl Quoter { impl Quoter {
pub fn new(safe: &[u8], protected: &[u8]) -> Quoter { pub fn new(safe: &[u8], protected: &[u8]) -> Quoter {
let mut q = Quoter { let mut quoter = Quoter {
safe_table: [0; 16], safe_table: [0; 16],
protected_table: [0; 16], protected_table: [0; 16],
}; };
@ -101,24 +110,24 @@ impl Quoter {
// prepare safe table // prepare safe table
for i in 0..128 { for i in 0..128 {
if ALLOWED.contains(&i) { if ALLOWED.contains(&i) {
set_bit(&mut q.safe_table, i); set_bit(&mut quoter.safe_table, i);
} }
if QS.contains(&i) { if QS.contains(&i) {
set_bit(&mut q.safe_table, i); set_bit(&mut quoter.safe_table, i);
} }
} }
for ch in safe { for ch in safe {
set_bit(&mut q.safe_table, *ch) set_bit(&mut quoter.safe_table, *ch)
} }
// prepare protected table // prepare protected table
for ch in protected { for ch in protected {
set_bit(&mut q.safe_table, *ch); set_bit(&mut quoter.safe_table, *ch);
set_bit(&mut q.protected_table, *ch); set_bit(&mut quoter.protected_table, *ch);
} }
q quoter
} }
pub fn requote(&self, val: &[u8]) -> Option<String> { pub fn requote(&self, val: &[u8]) -> Option<String> {
@ -215,7 +224,7 @@ mod tests {
} }
#[test] #[test]
fn test_parse_url() { fn parse_url() {
let re = "/user/{id}/test"; let re = "/user/{id}/test";
let path = match_url(re, "/user/2345/test"); let path = match_url(re, "/user/2345/test");
@ -231,24 +240,24 @@ mod tests {
} }
#[test] #[test]
fn test_protected_chars() { fn protected_chars() {
let encoded = percent_encode(PROTECTED); let encoded = percent_encode(PROTECTED);
let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded));
assert_eq!(path.get("id").unwrap(), &encoded); assert_eq!(path.get("id").unwrap(), &encoded);
} }
#[test] #[test]
fn test_non_protecteed_ascii() { fn non_protected_ascii() {
let nonprotected_ascii = ('\u{0}'..='\u{7F}') let non_protected_ascii = ('\u{0}'..='\u{7F}')
.filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8))) .filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8)))
.collect::<String>(); .collect::<String>();
let encoded = percent_encode(nonprotected_ascii.as_bytes()); let encoded = percent_encode(non_protected_ascii.as_bytes());
let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded));
assert_eq!(path.get("id").unwrap(), &nonprotected_ascii); assert_eq!(path.get("id").unwrap(), &non_protected_ascii);
} }
#[test] #[test]
fn test_valid_utf8_multibyte() { fn valid_utf8_multibyte() {
let test = ('\u{FF00}'..='\u{FFFF}').collect::<String>(); let test = ('\u{FF00}'..='\u{FFFF}').collect::<String>();
let encoded = percent_encode(test.as_bytes()); let encoded = percent_encode(test.as_bytes());
let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded)); let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded));
@ -256,7 +265,7 @@ mod tests {
} }
#[test] #[test]
fn test_invalid_utf8() { fn invalid_utf8() {
let invalid_utf8 = percent_encode((0x80..=0xff).collect::<Vec<_>>().as_slice()); let invalid_utf8 = percent_encode((0x80..=0xff).collect::<Vec<_>>().as_slice());
let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap(); let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap();
let path = Path::new(Url::new(uri)); let path = Path::new(Url::new(uri));
@ -266,7 +275,7 @@ mod tests {
} }
#[test] #[test]
fn test_from_hex() { fn hex_encoding() {
let hex = b"0123456789abcdefABCDEF"; let hex = b"0123456789abcdefABCDEF";
for i in 0..256 { for i in 0..256 {

View File

@ -4,15 +4,25 @@
set -x set -x
cargo test --lib --tests -p=actix-router --all-features EXIT=0
cargo test --lib --tests -p=actix-http --all-features
cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls
cargo test --lib --tests -p=actix-web-codegen --all-features
cargo test --lib --tests -p=awc --all-features
cargo test --lib --tests -p=actix-http-test --all-features
cargo test --lib --tests -p=actix-test --all-features
cargo test --lib --tests -p=actix-files
cargo test --lib --tests -p=actix-multipart --all-features
cargo test --lib --tests -p=actix-web-actors --all-features
cargo test --workspace --doc save_exit_code() {
eval $@
local CMD_EXIT=$?
[ "$CMD_EXIT" = "0" ] || EXIT=$CMD_EXIT
}
save_exit_code cargo test --lib --tests -p=actix-router --all-features
save_exit_code cargo test --lib --tests -p=actix-http --all-features
save_exit_code cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls
save_exit_code cargo test --lib --tests -p=actix-web-codegen --all-features
save_exit_code cargo test --lib --tests -p=awc --all-features
save_exit_code cargo test --lib --tests -p=actix-http-test --all-features
save_exit_code cargo test --lib --tests -p=actix-test --all-features
save_exit_code cargo test --lib --tests -p=actix-files
save_exit_code cargo test --lib --tests -p=actix-multipart --all-features
save_exit_code cargo test --lib --tests -p=actix-web-actors --all-features
save_exit_code cargo test --workspace --doc
exit $EXIT

View File

@ -6,12 +6,15 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use actix_http::body::MessageBody;
use actix_service::{Service, Transform};
use futures_core::{future::LocalBoxFuture, ready}; use futures_core::{future::LocalBoxFuture, ready};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use crate::{error::Error, service::ServiceResponse}; use crate::{
body::{BoxBody, MessageBody},
dev::{Service, Transform},
error::Error,
service::ServiceResponse,
};
/// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap), /// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap),
/// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition). /// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition).
@ -52,7 +55,7 @@ where
T::Response: MapServiceResponseBody, T::Response: MapServiceResponseBody,
T::Error: Into<Error>, T::Error: Into<Error>,
{ {
type Response = ServiceResponse; type Response = ServiceResponse<BoxBody>;
type Error = Error; type Error = Error;
type Transform = CompatMiddleware<T::Transform>; type Transform = CompatMiddleware<T::Transform>;
type InitError = T::InitError; type InitError = T::InitError;
@ -77,7 +80,7 @@ where
S::Response: MapServiceResponseBody, S::Response: MapServiceResponseBody,
S::Error: Into<Error>, S::Error: Into<Error>,
{ {
type Response = ServiceResponse; type Response = ServiceResponse<BoxBody>;
type Error = Error; type Error = Error;
type Future = CompatMiddlewareFuture<S::Future>; type Future = CompatMiddlewareFuture<S::Future>;
@ -102,7 +105,7 @@ where
T: MapServiceResponseBody, T: MapServiceResponseBody,
E: Into<Error>, E: Into<Error>,
{ {
type Output = Result<ServiceResponse, Error>; type Output = Result<ServiceResponse<BoxBody>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match ready!(self.project().fut.poll(cx)) { let res = match ready!(self.project().fut.poll(cx)) {
@ -116,14 +119,15 @@ where
/// Convert `ServiceResponse`'s `ResponseBody<B>` generic type to `ResponseBody<Body>`. /// Convert `ServiceResponse`'s `ResponseBody<B>` generic type to `ResponseBody<Body>`.
pub trait MapServiceResponseBody { pub trait MapServiceResponseBody {
fn map_body(self) -> ServiceResponse; fn map_body(self) -> ServiceResponse<BoxBody>;
} }
impl<B> MapServiceResponseBody for ServiceResponse<B> impl<B> MapServiceResponseBody for ServiceResponse<B>
where where
B: MessageBody + Unpin + 'static, B: MessageBody + 'static,
{ {
fn map_body(self) -> ServiceResponse { #[inline]
fn map_body(self) -> ServiceResponse<BoxBody> {
self.map_into_boxed_body() self.map_into_boxed_body()
} }
} }

View File

@ -106,7 +106,7 @@ mod tests {
header::{HeaderValue, CONTENT_TYPE}, header::{HeaderValue, CONTENT_TYPE},
StatusCode, StatusCode,
}, },
middleware::err_handlers::*, middleware::{err_handlers::*, Compat},
test::{self, TestRequest}, test::{self, TestRequest},
HttpResponse, HttpResponse,
}; };
@ -116,7 +116,8 @@ mod tests {
res.response_mut() res.response_mut()
.headers_mut() .headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res))
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
} }
#[actix_rt::test] #[actix_rt::test]
@ -125,7 +126,9 @@ mod tests {
ok(req.into_response(HttpResponse::InternalServerError().finish())) ok(req.into_response(HttpResponse::InternalServerError().finish()))
}; };
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Compat::new(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
);
let mw = Condition::new(true, mw) let mw = Condition::new(true, mw)
.new_transform(srv.into_service()) .new_transform(srv.into_service())
@ -141,7 +144,9 @@ mod tests {
ok(req.into_response(HttpResponse::InternalServerError().finish())) ok(req.into_response(HttpResponse::InternalServerError().finish()))
}; };
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Compat::new(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
);
let mw = Condition::new(false, mw) let mw = Condition::new(false, mw)
.new_transform(srv.into_service()) .new_transform(srv.into_service())

View File

@ -13,6 +13,7 @@ use futures_core::{future::LocalBoxFuture, ready};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use crate::{ use crate::{
body::EitherBody,
dev::{ServiceRequest, ServiceResponse}, dev::{ServiceRequest, ServiceResponse},
http::StatusCode, http::StatusCode,
Error, Result, Error, Result,
@ -21,10 +22,10 @@ use crate::{
/// Return type for [`ErrorHandlers`] custom handlers. /// Return type for [`ErrorHandlers`] custom handlers.
pub enum ErrorHandlerResponse<B> { pub enum ErrorHandlerResponse<B> {
/// Immediate HTTP response. /// Immediate HTTP response.
Response(ServiceResponse<B>), Response(ServiceResponse<EitherBody<B>>),
/// A future that resolves to an HTTP response. /// A future that resolves to an HTTP response.
Future(LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>), Future(LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>),
} }
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>; type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
@ -44,7 +45,8 @@ type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse
/// res.response_mut() /// res.response_mut()
/// .headers_mut() /// .headers_mut()
/// .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("Error")); /// .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("Error"));
/// Ok(ErrorHandlerResponse::Response(res)) ///
/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// } /// }
/// ///
/// let app = App::new() /// let app = App::new()
@ -66,7 +68,7 @@ type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
impl<B> Default for ErrorHandlers<B> { impl<B> Default for ErrorHandlers<B> {
fn default() -> Self { fn default() -> Self {
ErrorHandlers { ErrorHandlers {
handlers: Rc::new(AHashMap::default()), handlers: Default::default(),
} }
} }
} }
@ -95,7 +97,7 @@ where
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse<EitherBody<B>>;
type Error = Error; type Error = Error;
type Transform = ErrorHandlersMiddleware<S, B>; type Transform = ErrorHandlersMiddleware<S, B>;
type InitError = (); type InitError = ();
@ -119,7 +121,7 @@ where
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse<EitherBody<B>>;
type Error = Error; type Error = Error;
type Future = ErrorHandlersFuture<S::Future, B>; type Future = ErrorHandlersFuture<S::Future, B>;
@ -143,8 +145,8 @@ pin_project! {
fut: Fut, fut: Fut,
handlers: Handlers<B>, handlers: Handlers<B>,
}, },
HandlerFuture { ErrorHandlerFuture {
fut: LocalBoxFuture<'static, Fut::Output>, fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
}, },
} }
} }
@ -153,25 +155,29 @@ impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
where where
Fut: Future<Output = Result<ServiceResponse<B>, Error>>, Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
{ {
type Output = Fut::Output; type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() { match self.as_mut().project() {
ErrorHandlersProj::ServiceFuture { fut, handlers } => { ErrorHandlersProj::ServiceFuture { fut, handlers } => {
let res = ready!(fut.poll(cx))?; let res = ready!(fut.poll(cx))?;
match handlers.get(&res.status()) { match handlers.get(&res.status()) {
Some(handler) => match handler(res)? { Some(handler) => match handler(res)? {
ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
ErrorHandlerResponse::Future(fut) => { ErrorHandlerResponse::Future(fut) => {
self.as_mut() self.as_mut()
.set(ErrorHandlersFuture::HandlerFuture { fut }); .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
self.poll(cx) self.poll(cx)
} }
}, },
None => Poll::Ready(Ok(res)),
None => Poll::Ready(Ok(res.map_into_left_body())),
} }
} }
ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx),
ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
} }
} }
} }
@ -180,32 +186,33 @@ where
mod tests { mod tests {
use actix_service::IntoService; use actix_service::IntoService;
use actix_utils::future::ok; use actix_utils::future::ok;
use bytes::Bytes;
use futures_util::future::FutureExt as _; use futures_util::future::FutureExt as _;
use super::*; use super::*;
use crate::http::{ use crate::{
http::{
header::{HeaderValue, CONTENT_TYPE}, header::{HeaderValue, CONTENT_TYPE},
StatusCode, StatusCode,
},
test::{self, TestRequest},
}; };
use crate::test::{self, TestRequest};
use crate::HttpResponse;
#[actix_rt::test]
async fn add_header_error_handler() {
#[allow(clippy::unnecessary_wraps)] #[allow(clippy::unnecessary_wraps)]
fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> { fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut() res.response_mut()
.headers_mut() .headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res))
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
} }
#[actix_rt::test] let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
async fn test_handler() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = ErrorHandlers::new() let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service()) .new_transform(srv.into_service())
.await .await
.unwrap(); .unwrap();
@ -214,24 +221,25 @@ mod tests {
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
} }
#[actix_rt::test]
async fn add_header_error_handler_async() {
#[allow(clippy::unnecessary_wraps)] #[allow(clippy::unnecessary_wraps)]
fn render_500_async<B: 'static>( fn error_handler<B: 'static>(
mut res: ServiceResponse<B>, mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> { ) -> Result<ErrorHandlerResponse<B>> {
res.response_mut() res.response_mut()
.headers_mut() .headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Future(ok(res).boxed_local()))
Ok(ErrorHandlerResponse::Future(
ok(res.map_into_left_body()).boxed_local(),
))
} }
#[actix_rt::test] let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
async fn test_handler_async() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = ErrorHandlers::new() let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service()) .new_transform(srv.into_service())
.await .await
.unwrap(); .unwrap();
@ -239,4 +247,34 @@ mod tests {
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
} }
#[actix_rt::test]
async fn changes_body_type() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B: 'static>(
res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
let (req, res) = res.into_parts();
let res = res.set_body(Bytes::from("sorry, that's no bueno"));
let res = ServiceResponse::new(req, res)
.map_into_boxed_body()
.map_into_right_body();
Ok(ErrorHandlerResponse::Response(res))
}
let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
}
// TODO: test where error is thrown
} }

View File

@ -35,7 +35,7 @@ mod tests {
.wrap(Condition::new(true, DefaultHeaders::new())) .wrap(Condition::new(true, DefaultHeaders::new()))
.wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2")))
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res)) Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
})) }))
.wrap(Logger::default()) .wrap(Logger::default())
.wrap(NormalizePath::new(TrailingSlash::Trim)); .wrap(NormalizePath::new(TrailingSlash::Trim));
@ -44,7 +44,7 @@ mod tests {
.wrap(NormalizePath::new(TrailingSlash::Trim)) .wrap(NormalizePath::new(TrailingSlash::Trim))
.wrap(Logger::default()) .wrap(Logger::default())
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res)) Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
})) }))
.wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2")))
.wrap(Condition::new(true, DefaultHeaders::new())) .wrap(Condition::new(true, DefaultHeaders::new()))