From 0c144054cba2fee55bad6183865b98f3d96a74a2 Mon Sep 17 00:00:00 2001 From: Ali MJ Al-Nasrawy Date: Tue, 8 Feb 2022 10:50:05 +0300 Subject: [PATCH] make `Condition` generic over body type (#2635) Co-authored-by: Rob Ede --- actix-web/CHANGES.md | 4 + actix-web/src/middleware/condition.rs | 120 ++++++++++++++++++-------- 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/actix-web/CHANGES.md b/actix-web/CHANGES.md index 37fe6ca6..78fd50c8 100644 --- a/actix-web/CHANGES.md +++ b/actix-web/CHANGES.md @@ -1,11 +1,15 @@ # Changes ## Unreleased - 2021-xx-xx +### Changed +- `middleware::Condition` gained a broader compatibility; `Compat` is needed in fewer cases. [#2635] + ### Added - Implement `Responder` for `Vec`. [#2625] - Re-export `KeepAlive` in `http` mod. [#2625] [#2625]: https://github.com/actix/actix-web/pull/2625 +[#2635]: https://github.com/actix/actix-web/pull/2635 ## 4.0.0-rc.2 - 2022-02-02 diff --git a/actix-web/src/middleware/condition.rs b/actix-web/src/middleware/condition.rs index 659f88bc..65f25a67 100644 --- a/actix-web/src/middleware/condition.rs +++ b/actix-web/src/middleware/condition.rs @@ -1,18 +1,22 @@ //! For middleware documentation, see [`Condition`]. -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; -use actix_service::{Service, Transform}; -use actix_utils::future::Either; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; use futures_util::future::FutureExt as _; +use pin_project_lite::pin_project; + +use crate::{ + body::EitherBody, + dev::{Service, ServiceResponse, Transform}, +}; /// Middleware for conditionally enabling other middleware. /// -/// The controlled middleware must not change the `Service` interfaces. This means you cannot -/// control such middlewares like `Logger` or `Compress` directly. See the [`Compat`](super::Compat) -/// middleware for a workaround. -/// /// # Examples /// ``` /// use actix_web::middleware::{Condition, NormalizePath}; @@ -36,16 +40,16 @@ impl Condition { } } -impl Transform for Condition +impl Transform for Condition where - S: Service + 'static, - T: Transform, + S: Service, Error = Err> + 'static, + T: Transform, Error = Err>, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, { - type Response = S::Response; - type Error = S::Error; + type Response = ServiceResponse>; + type Error = Err; type Transform = ConditionMiddleware; type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; @@ -69,14 +73,14 @@ pub enum ConditionMiddleware { Disable(D), } -impl Service for ConditionMiddleware +impl Service for ConditionMiddleware where - E: Service, - D: Service, + E: Service, Error = Err>, + D: Service, Error = Err>, { - type Response = E::Response; - type Error = E::Error; - type Future = Either; + type Response = ServiceResponse>; + type Error = Err; + type Future = ConditionMiddlewareFuture; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { match self { @@ -87,27 +91,59 @@ where fn call(&self, req: Req) -> Self::Future { match self { - ConditionMiddleware::Enable(service) => Either::left(service.call(req)), - ConditionMiddleware::Disable(service) => Either::right(service.call(req)), + ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Enabled { + fut: service.call(req), + }, + ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Disabled { + fut: service.call(req), + }, } } } +pin_project! { + #[doc(hidden)] + #[project = ConditionProj] + pub enum ConditionMiddlewareFuture { + Enabled { #[pin] fut: E, }, + Disabled { #[pin] fut: D, }, + } +} + +impl Future for ConditionMiddlewareFuture +where + E: Future, Err>>, + D: Future, Err>>, +{ + type Output = Result>, Err>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = match self.project() { + ConditionProj::Enabled { fut } => ready!(fut.poll(cx))?.map_into_left_body(), + ConditionProj::Disabled { fut } => ready!(fut.poll(cx))?.map_into_right_body(), + }; + + Poll::Ready(Ok(res)) + } +} + #[cfg(test)] mod tests { - use actix_service::IntoService; - use actix_utils::future::ok; + use actix_service::IntoService as _; use super::*; use crate::{ + body::BoxBody, dev::{ServiceRequest, ServiceResponse}, error::Result, http::{ header::{HeaderValue, CONTENT_TYPE}, StatusCode, }, - middleware::{err_handlers::*, Compat}, + middleware::{self, ErrorHandlerResponse, ErrorHandlers}, test::{self, TestRequest}, + web::Bytes, HttpResponse, }; @@ -120,40 +156,52 @@ mod tests { Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) } + #[test] + fn compat_with_builtin_middleware() { + let _ = Condition::new(true, middleware::Compat::noop()); + let _ = Condition::new(true, middleware::Logger::default()); + let _ = Condition::new(true, middleware::Compress::default()); + let _ = Condition::new(true, middleware::NormalizePath::trim()); + let _ = Condition::new(true, middleware::DefaultHeaders::new()); + let _ = Condition::new(true, middleware::ErrorHandlers::::new()); + let _ = Condition::new(true, middleware::ErrorHandlers::::new()); + } + #[actix_rt::test] async fn test_handler_enabled() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) + let srv = |req: ServiceRequest| async move { + let resp = HttpResponse::InternalServerError().message_body(String::new())?; + Ok(req.into_response(resp)) }; - let mw = Compat::new( - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), - ); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) .await .unwrap(); - let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + + let resp: ServiceResponse, String>> = + test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[actix_rt::test] async fn test_handler_disabled() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) + let srv = |req: ServiceRequest| async move { + let resp = HttpResponse::InternalServerError().message_body(String::new())?; + Ok(req.into_response(resp)) }; - let mw = Compat::new( - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), - ); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) .await .unwrap(); - let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + let resp: ServiceResponse, String>> = + test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE), None); } }