//! For middleware documentation, see [`Compat`]. use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use actix_http::body::{Body, MessageBody, ResponseBody}; use actix_service::{Service, Transform}; use futures_core::{future::LocalBoxFuture, ready}; use crate::{error::Error, service::ServiceResponse}; /// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap), /// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition). /// /// # Usage /// ```rust /// use actix_web::middleware::{Logger, Compat}; /// use actix_web::{App, web}; /// /// let logger = Logger::default(); /// /// // this would not compile because of incompatible body types /// // let app = App::new() /// // .service(web::scope("scoped").wrap(logger)); /// /// // by using this middleware we can use the logger on a scope /// let app = App::new() /// .service(web::scope("scoped").wrap(Compat::new(logger))); /// ``` pub struct Compat { transform: T, } impl Compat { /// Wrap a middleware to give it broader compatibility. pub fn new(middleware: T) -> Self { Self { transform: middleware, } } } impl Transform for Compat where S: Service, T: Transform, T::Future: 'static, T::Response: MapServiceResponseBody, Error: From, { type Response = ServiceResponse; type Error = Error; type Transform = CompatMiddleware; type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { let fut = self.transform.new_transform(service); Box::pin(async move { let service = fut.await?; Ok(CompatMiddleware { service }) }) } } pub struct CompatMiddleware { service: S, } impl Service for CompatMiddleware where S: Service, S::Response: MapServiceResponseBody, Error: From, { type Response = ServiceResponse; type Error = Error; type Future = CompatMiddlewareFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx).map_err(From::from) } fn call(&mut self, req: Req) -> Self::Future { let fut = self.service.call(req); CompatMiddlewareFuture { fut } } } #[pin_project::pin_project] pub struct CompatMiddlewareFuture { #[pin] fut: Fut, } impl Future for CompatMiddlewareFuture where Fut: Future>, T: MapServiceResponseBody, Error: From, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.project().fut.poll(cx))?; Poll::Ready(Ok(res.map_body())) } } /// Convert `ServiceResponse`'s `ResponseBody` generic type to `ResponseBody`. pub trait MapServiceResponseBody { fn map_body(self) -> ServiceResponse; } impl MapServiceResponseBody for ServiceResponse { fn map_body(self) -> ServiceResponse { self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) } } #[cfg(test)] mod tests { use super::*; use actix_service::IntoService; use crate::dev::ServiceRequest; use crate::http::StatusCode; use crate::middleware::{Compress, Condition, Logger}; use crate::test::{call_service, init_service, TestRequest}; use crate::{web, App, HttpResponse}; #[actix_rt::test] async fn test_scope_middleware() { let logger = Logger::default(); let compress = Compress::default(); let mut srv = init_service( App::new().service( web::scope("app") .wrap(Compat::new(logger)) .wrap(Compat::new(compress)) .service( web::resource("/test").route(web::get().to(HttpResponse::Ok)), ), ), ) .await; let req = TestRequest::with_uri("/app/test").to_request(); let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_resource_scope_middleware() { let logger = Logger::default(); let compress = Compress::default(); let mut srv = init_service( App::new().service( web::resource("app/test") .wrap(Compat::new(logger)) .wrap(Compat::new(compress)) .route(web::get().to(HttpResponse::Ok)), ), ) .await; let req = TestRequest::with_uri("/app/test").to_request(); let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_condition_scope_middleware() { let srv = |req: ServiceRequest| { Box::pin(async move { Ok(req.into_response(HttpResponse::InternalServerError().finish())) }) }; let logger = Logger::default(); let mut mw = Condition::new(true, Compat::new(logger)) .new_transform(srv.into_service()) .await .unwrap(); let resp = call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } }