diff --git a/CHANGES.md b/CHANGES.md index 5f8f489fc..57304d083 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,4 +1,10 @@ # Changes +## not released yet + +### Added + +* Add `middleware::Conditon` that conditionally enables another middleware + ## [1.0.7] - 2019-08-29 diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs new file mode 100644 index 000000000..ddc5fdd42 --- /dev/null +++ b/src/middleware/condition.rs @@ -0,0 +1,143 @@ +//! `Middleware` for conditionally enables another middleware. +use actix_service::{Service, Transform}; +use futures::future::{ok, Either, FutureResult, Map}; +use futures::{Future, Poll}; + +/// `Middleware` for conditionally enables another middleware. +/// The controled middleware must not change the `Service` interfaces. +/// This means you cannot control such middlewares like `Logger` or `Compress`. +/// +/// ## Usage +/// +/// ```rust +/// use actix_web::middleware::{Condition, NormalizePath}; +/// use actix_web::App; +/// +/// fn main() { +/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let app = App::new() +/// .wrap(Condition::new(enable_normalize, NormalizePath)); +/// } +/// ``` +pub struct Condition { + trans: T, + enable: bool, +} + +impl Condition { + pub fn new(enable: bool, trans: T) -> Self { + Self { trans, enable } + } +} + +impl Transform for Condition +where + S: Service, + T: Transform, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type InitError = T::InitError; + type Transform = ConditionMiddleware; + type Future = Either< + Map Self::Transform>, + FutureResult, + >; + + fn new_transform(&self, service: S) -> Self::Future { + if self.enable { + let f = self + .trans + .new_transform(service) + .map(ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform); + Either::A(f) + } else { + Either::B(ok(ConditionMiddleware::Disable(service))) + } + } +} + +pub enum ConditionMiddleware { + Enable(E), + Disable(D), +} + +impl Service for ConditionMiddleware +where + E: Service, + D: Service, +{ + type Request = E::Request; + type Response = E::Response; + type Error = E::Error; + type Future = Either; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + use ConditionMiddleware::*; + match self { + Enable(service) => service.poll_ready(), + Disable(service) => service.poll_ready(), + } + } + + fn call(&mut self, req: E::Request) -> Self::Future { + use ConditionMiddleware::*; + match self { + Enable(service) => Either::A(service.call(req)), + Disable(service) => Either::B(service.call(req)), + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::IntoService; + + use super::*; + use crate::dev::{ServiceRequest, ServiceResponse}; + use crate::error::Result; + use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; + use crate::middleware::errhandlers::*; + use crate::test::{self, TestRequest}; + use crate::HttpResponse; + + fn render_500(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res)) + } + + #[test] + fn test_handler_enabled() { + let srv = |req: ServiceRequest| { + req.into_response(HttpResponse::InternalServerError().finish()) + }; + + let mw = + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + + let mut mw = + test::block_on(Condition::new(true, mw).new_transform(srv.into_service())) + .unwrap(); + let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + #[test] + fn test_handler_disabled() { + let srv = |req: ServiceRequest| { + req.into_response(HttpResponse::InternalServerError().finish()) + }; + + let mw = + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + + let mut mw = + test::block_on(Condition::new(false, mw).new_transform(srv.into_service())) + .unwrap(); + + let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 814993f0c..311d0ee99 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -6,7 +6,9 @@ mod defaultheaders; pub mod errhandlers; mod logger; mod normalize; +mod condition; pub use self::defaultheaders::DefaultHeaders; pub use self::logger::Logger; pub use self::normalize::NormalizePath; +pub use self::condition::Condition;