//! For middleware documentation, see [`Condition`]. use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use futures_util::future::{Either, FutureExt, LocalBoxFuture}; /// Middleware for conditionally enabling other middleware. /// /// The controlled middleware must not change the `Service` interfaces. This means you cannot /// control such middlewares like `Logger` or `Compress` directly. See the [`Compat`](super::Compat) /// middleware for a workaround. /// /// # Usage /// ```rust /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// /// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok(); /// let app = App::new() /// .wrap(Condition::new(enable_normalize, NormalizePath::default())); /// ``` pub struct Condition { transformer: T, enable: bool, } impl Condition { pub fn new(enable: bool, transformer: T) -> Self { Self { transformer, enable, } } } impl Transform for Condition where S: Service + 'static, T: Transform, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, { type Response = S::Response; type Error = S::Error; type Transform = ConditionMiddleware; type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { let fut = self.transformer.new_transform(service); async move { let wrapped_svc = fut.await?; Ok(ConditionMiddleware::Enable(wrapped_svc)) } .boxed_local() } else { async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local() } } } pub enum ConditionMiddleware { Enable(E), Disable(D), } impl Service for ConditionMiddleware where E: Service, D: Service, { type Response = E::Response; type Error = E::Error; type Future = Either; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { match self { ConditionMiddleware::Enable(service) => service.poll_ready(cx), ConditionMiddleware::Disable(service) => service.poll_ready(cx), } } fn call(&self, req: Req) -> Self::Future { match self { ConditionMiddleware::Enable(service) => Either::Left(service.call(req)), ConditionMiddleware::Disable(service) => Either::Right(service.call(req)), } } } #[cfg(test)] mod tests { use actix_service::IntoService; use futures_util::future::ok; use super::*; use crate::{ dev::{ServiceRequest, ServiceResponse}, error::Result, http::{header::CONTENT_TYPE, HeaderValue, StatusCode}, middleware::err_handlers::*, test::{self, TestRequest}, HttpResponse, }; fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(ErrorHandlerResponse::Response(res)) } #[actix_rt::test] async fn test_handler_enabled() { let srv = |req: ServiceRequest| { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) .await .unwrap(); let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[actix_rt::test] async fn test_handler_disabled() { let srv = |req: ServiceRequest| { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) .await .unwrap(); let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE), None); } }