1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-23 16:21:06 +01:00

allow error handler middleware to return different body type (#2515)

This commit is contained in:
Rob Ede 2021-12-14 21:17:50 +00:00 committed by GitHub
parent 05255c7f7c
commit dd4a372613
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 162 additions and 93 deletions

View File

@ -7,12 +7,15 @@
### Changed
* Align `DefaultHeader` method terminology, deprecating previous methods. [#2510]
* Response service types in `ErrorHandlers` middleware now use `ServiceResponse<EitherBody<B>>` to allow changing the body type. [#2515]
* Both variants in `ErrorHandlerResponse` now use `ServiceResponse<EitherBody<B>>`. [#2515]
### Removed
* Top-level `EitherExtractError` export. [#2510]
* Conversion implementations for `either` crate. [#2516]
[#2510]: https://github.com/actix/actix-web/pull/2510
[#2515]: https://github.com/actix/actix-web/pull/2515
[#2516]: https://github.com/actix/actix-web/pull/2516

View File

@ -2,22 +2,28 @@ use crate::ResourcePath;
#[allow(dead_code)]
const GEN_DELIMS: &[u8] = b":/?#[]@";
#[allow(dead_code)]
const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,";
#[allow(dead_code)]
const SUB_DELIMS: &[u8] = b"!$'()*,+?=;";
#[allow(dead_code)]
const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;";
#[allow(dead_code)]
const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ
1234567890
-._~";
const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz
ABCDEFGHIJKLMNOPQRSTUVWXYZ
1234567890
-._~
!$'()*,";
const QS: &[u8] = b"+&=;b";
#[inline]
@ -34,19 +40,20 @@ thread_local! {
static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+");
}
#[derive(Default, Clone, Debug)]
#[derive(Debug, Clone, Default)]
pub struct Url {
uri: http::Uri,
path: Option<String>,
}
impl Url {
#[inline]
pub fn new(uri: http::Uri) -> Url {
let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes()));
Url { uri, path }
}
#[inline]
pub fn with_quoter(uri: http::Uri, quoter: &Quoter) -> Url {
Url {
path: quoter.requote(uri.path().as_bytes()),
@ -54,15 +61,16 @@ impl Url {
}
}
#[inline]
pub fn uri(&self) -> &http::Uri {
&self.uri
}
#[inline]
pub fn path(&self) -> &str {
if let Some(ref s) = self.path {
s
} else {
self.uri.path()
match self.path {
Some(ref path) => path,
_ => self.uri.path(),
}
}
@ -86,6 +94,7 @@ impl ResourcePath for Url {
}
}
/// A quoter
pub struct Quoter {
safe_table: [u8; 16],
protected_table: [u8; 16],
@ -93,7 +102,7 @@ pub struct Quoter {
impl Quoter {
pub fn new(safe: &[u8], protected: &[u8]) -> Quoter {
let mut q = Quoter {
let mut quoter = Quoter {
safe_table: [0; 16],
protected_table: [0; 16],
};
@ -101,24 +110,24 @@ impl Quoter {
// prepare safe table
for i in 0..128 {
if ALLOWED.contains(&i) {
set_bit(&mut q.safe_table, i);
set_bit(&mut quoter.safe_table, i);
}
if QS.contains(&i) {
set_bit(&mut q.safe_table, i);
set_bit(&mut quoter.safe_table, i);
}
}
for ch in safe {
set_bit(&mut q.safe_table, *ch)
set_bit(&mut quoter.safe_table, *ch)
}
// prepare protected table
for ch in protected {
set_bit(&mut q.safe_table, *ch);
set_bit(&mut q.protected_table, *ch);
set_bit(&mut quoter.safe_table, *ch);
set_bit(&mut quoter.protected_table, *ch);
}
q
quoter
}
pub fn requote(&self, val: &[u8]) -> Option<String> {
@ -215,7 +224,7 @@ mod tests {
}
#[test]
fn test_parse_url() {
fn parse_url() {
let re = "/user/{id}/test";
let path = match_url(re, "/user/2345/test");
@ -231,24 +240,24 @@ mod tests {
}
#[test]
fn test_protected_chars() {
fn protected_chars() {
let encoded = percent_encode(PROTECTED);
let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded));
assert_eq!(path.get("id").unwrap(), &encoded);
}
#[test]
fn test_non_protecteed_ascii() {
let nonprotected_ascii = ('\u{0}'..='\u{7F}')
fn non_protected_ascii() {
let non_protected_ascii = ('\u{0}'..='\u{7F}')
.filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8)))
.collect::<String>();
let encoded = percent_encode(nonprotected_ascii.as_bytes());
let encoded = percent_encode(non_protected_ascii.as_bytes());
let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded));
assert_eq!(path.get("id").unwrap(), &nonprotected_ascii);
assert_eq!(path.get("id").unwrap(), &non_protected_ascii);
}
#[test]
fn test_valid_utf8_multibyte() {
fn valid_utf8_multibyte() {
let test = ('\u{FF00}'..='\u{FFFF}').collect::<String>();
let encoded = percent_encode(test.as_bytes());
let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded));
@ -256,7 +265,7 @@ mod tests {
}
#[test]
fn test_invalid_utf8() {
fn invalid_utf8() {
let invalid_utf8 = percent_encode((0x80..=0xff).collect::<Vec<_>>().as_slice());
let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap();
let path = Path::new(Url::new(uri));
@ -266,7 +275,7 @@ mod tests {
}
#[test]
fn test_from_hex() {
fn hex_encoding() {
let hex = b"0123456789abcdefABCDEF";
for i in 0..256 {

View File

@ -4,15 +4,25 @@
set -x
cargo test --lib --tests -p=actix-router --all-features
cargo test --lib --tests -p=actix-http --all-features
cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls
cargo test --lib --tests -p=actix-web-codegen --all-features
cargo test --lib --tests -p=awc --all-features
cargo test --lib --tests -p=actix-http-test --all-features
cargo test --lib --tests -p=actix-test --all-features
cargo test --lib --tests -p=actix-files
cargo test --lib --tests -p=actix-multipart --all-features
cargo test --lib --tests -p=actix-web-actors --all-features
EXIT=0
cargo test --workspace --doc
save_exit_code() {
eval $@
local CMD_EXIT=$?
[ "$CMD_EXIT" = "0" ] || EXIT=$CMD_EXIT
}
save_exit_code cargo test --lib --tests -p=actix-router --all-features
save_exit_code cargo test --lib --tests -p=actix-http --all-features
save_exit_code cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls
save_exit_code cargo test --lib --tests -p=actix-web-codegen --all-features
save_exit_code cargo test --lib --tests -p=awc --all-features
save_exit_code cargo test --lib --tests -p=actix-http-test --all-features
save_exit_code cargo test --lib --tests -p=actix-test --all-features
save_exit_code cargo test --lib --tests -p=actix-files
save_exit_code cargo test --lib --tests -p=actix-multipart --all-features
save_exit_code cargo test --lib --tests -p=actix-web-actors --all-features
save_exit_code cargo test --workspace --doc
exit $EXIT

View File

@ -6,12 +6,15 @@ use std::{
task::{Context, Poll},
};
use actix_http::body::MessageBody;
use actix_service::{Service, Transform};
use futures_core::{future::LocalBoxFuture, ready};
use pin_project_lite::pin_project;
use crate::{error::Error, service::ServiceResponse};
use crate::{
body::{BoxBody, MessageBody},
dev::{Service, Transform},
error::Error,
service::ServiceResponse,
};
/// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap),
/// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition).
@ -52,7 +55,7 @@ where
T::Response: MapServiceResponseBody,
T::Error: Into<Error>,
{
type Response = ServiceResponse;
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type Transform = CompatMiddleware<T::Transform>;
type InitError = T::InitError;
@ -77,7 +80,7 @@ where
S::Response: MapServiceResponseBody,
S::Error: Into<Error>,
{
type Response = ServiceResponse;
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type Future = CompatMiddlewareFuture<S::Future>;
@ -102,7 +105,7 @@ where
T: MapServiceResponseBody,
E: Into<Error>,
{
type Output = Result<ServiceResponse, Error>;
type Output = Result<ServiceResponse<BoxBody>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match ready!(self.project().fut.poll(cx)) {
@ -116,14 +119,15 @@ where
/// Convert `ServiceResponse`'s `ResponseBody<B>` generic type to `ResponseBody<Body>`.
pub trait MapServiceResponseBody {
fn map_body(self) -> ServiceResponse;
fn map_body(self) -> ServiceResponse<BoxBody>;
}
impl<B> MapServiceResponseBody for ServiceResponse<B>
where
B: MessageBody + Unpin + 'static,
B: MessageBody + 'static,
{
fn map_body(self) -> ServiceResponse {
#[inline]
fn map_body(self) -> ServiceResponse<BoxBody> {
self.map_into_boxed_body()
}
}

View File

@ -106,7 +106,7 @@ mod tests {
header::{HeaderValue, CONTENT_TYPE},
StatusCode,
},
middleware::err_handlers::*,
middleware::{err_handlers::*, Compat},
test::{self, TestRequest},
HttpResponse,
};
@ -116,7 +116,8 @@ mod tests {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res))
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
#[actix_rt::test]
@ -125,7 +126,9 @@ mod tests {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Compat::new(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
);
let mw = Condition::new(true, mw)
.new_transform(srv.into_service())
@ -141,7 +144,9 @@ mod tests {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Compat::new(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
);
let mw = Condition::new(false, mw)
.new_transform(srv.into_service())

View File

@ -13,6 +13,7 @@ use futures_core::{future::LocalBoxFuture, ready};
use pin_project_lite::pin_project;
use crate::{
body::EitherBody,
dev::{ServiceRequest, ServiceResponse},
http::StatusCode,
Error, Result,
@ -21,10 +22,10 @@ use crate::{
/// Return type for [`ErrorHandlers`] custom handlers.
pub enum ErrorHandlerResponse<B> {
/// Immediate HTTP response.
Response(ServiceResponse<B>),
Response(ServiceResponse<EitherBody<B>>),
/// A future that resolves to an HTTP response.
Future(LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>),
Future(LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>),
}
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
@ -44,7 +45,8 @@ type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse
/// res.response_mut()
/// .headers_mut()
/// .insert(header::CONTENT_TYPE, header::HeaderValue::from_static("Error"));
/// Ok(ErrorHandlerResponse::Response(res))
///
/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// }
///
/// let app = App::new()
@ -66,7 +68,7 @@ type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
impl<B> Default for ErrorHandlers<B> {
fn default() -> Self {
ErrorHandlers {
handlers: Rc::new(AHashMap::default()),
handlers: Default::default(),
}
}
}
@ -95,7 +97,7 @@ where
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Transform = ErrorHandlersMiddleware<S, B>;
type InitError = ();
@ -119,7 +121,7 @@ where
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = ErrorHandlersFuture<S::Future, B>;
@ -143,8 +145,8 @@ pin_project! {
fut: Fut,
handlers: Handlers<B>,
},
HandlerFuture {
fut: LocalBoxFuture<'static, Fut::Output>,
ErrorHandlerFuture {
fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
},
}
}
@ -153,25 +155,29 @@ impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
where
Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
{
type Output = Fut::Output;
type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
ErrorHandlersProj::ServiceFuture { fut, handlers } => {
let res = ready!(fut.poll(cx))?;
match handlers.get(&res.status()) {
Some(handler) => match handler(res)? {
ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
ErrorHandlerResponse::Future(fut) => {
self.as_mut()
.set(ErrorHandlersFuture::HandlerFuture { fut });
.set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
self.poll(cx)
}
},
None => Poll::Ready(Ok(res)),
None => Poll::Ready(Ok(res.map_into_left_body())),
}
}
ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx),
ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
}
}
}
@ -180,32 +186,33 @@ where
mod tests {
use actix_service::IntoService;
use actix_utils::future::ok;
use bytes::Bytes;
use futures_util::future::FutureExt as _;
use super::*;
use crate::http::{
header::{HeaderValue, CONTENT_TYPE},
StatusCode,
use crate::{
http::{
header::{HeaderValue, CONTENT_TYPE},
StatusCode,
},
test::{self, TestRequest},
};
use crate::test::{self, TestRequest};
use crate::HttpResponse;
#[allow(clippy::unnecessary_wraps)]
fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res))
}
#[actix_rt::test]
async fn test_handler() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
async fn add_header_error_handler() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500)
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
@ -214,24 +221,25 @@ mod tests {
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[allow(clippy::unnecessary_wraps)]
fn render_500_async<B: 'static>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Future(ok(res).boxed_local()))
}
#[actix_rt::test]
async fn test_handler_async() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
async fn add_header_error_handler_async() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B: 'static>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Future(
ok(res.map_into_left_body()).boxed_local(),
))
}
let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async)
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
@ -239,4 +247,34 @@ mod tests {
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn changes_body_type() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B: 'static>(
res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
let (req, res) = res.into_parts();
let res = res.set_body(Bytes::from("sorry, that's no bueno"));
let res = ServiceResponse::new(req, res)
.map_into_boxed_body()
.map_into_right_body();
Ok(ErrorHandlerResponse::Response(res))
}
let srv = test::default_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
}
// TODO: test where error is thrown
}

View File

@ -35,7 +35,7 @@ mod tests {
.wrap(Condition::new(true, DefaultHeaders::new()))
.wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2")))
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res))
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}))
.wrap(Logger::default())
.wrap(NormalizePath::new(TrailingSlash::Trim));
@ -44,7 +44,7 @@ mod tests {
.wrap(NormalizePath::new(TrailingSlash::Trim))
.wrap(Logger::default())
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res))
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}))
.wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2")))
.wrap(Condition::new(true, DefaultHeaders::new()))