diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index ddc5fdd4..6603fc00 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -1,7 +1,8 @@ //! `Middleware` for conditionally enables another middleware. +use std::task::{Context, Poll}; + use actix_service::{Service, Transform}; -use futures::future::{ok, Either, FutureResult, Map}; -use futures::{Future, Poll}; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; /// `Middleware` for conditionally enables another middleware. /// The controled middleware must not change the `Service` interfaces. @@ -13,11 +14,11 @@ use futures::{Future, Poll}; /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// -/// fn main() { -/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); -/// let app = App::new() -/// .wrap(Condition::new(enable_normalize, NormalizePath)); -/// } +/// # fn main() { +/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let app = App::new() +/// .wrap(Condition::new(enable_normalize, NormalizePath)); +/// # } /// ``` pub struct Condition { trans: T, @@ -32,29 +33,31 @@ impl Condition { impl Transform for Condition where - S: Service, + S: Service + 'static, T: Transform, + T::Future: 'static, + T::InitError: 'static, + T::Transform: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = T::InitError; type Transform = ConditionMiddleware; - type Future = Either< - Map Self::Transform>, - FutureResult, - >; + type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { - let f = self - .trans - .new_transform(service) - .map(ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform); - Either::A(f) + let f = self.trans.new_transform(service).map(|res| { + res.map( + ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, + ) + }); + Either::Left(f) } else { - Either::B(ok(ConditionMiddleware::Disable(service))) + Either::Right(ok(ConditionMiddleware::Disable(service))) } + .boxed_local() } } @@ -73,19 +76,19 @@ where type Error = E::Error; type Future = Either; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { use ConditionMiddleware::*; match self { - Enable(service) => service.poll_ready(), - Disable(service) => service.poll_ready(), + Enable(service) => service.poll_ready(cx), + Disable(service) => service.poll_ready(cx), } } fn call(&mut self, req: E::Request) -> Self::Future { use ConditionMiddleware::*; match self { - Enable(service) => Either::A(service.call(req)), - Disable(service) => Either::B(service.call(req)), + Enable(service) => Either::Left(service.call(req)), + Disable(service) => Either::Right(service.call(req)), } } } @@ -99,7 +102,7 @@ mod tests { use crate::error::Result; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::middleware::errhandlers::*; - use crate::test::{self, TestRequest}; + use crate::test::{self, block_on, TestRequest}; use crate::HttpResponse; fn render_500(mut res: ServiceResponse) -> Result> { @@ -111,33 +114,44 @@ mod tests { #[test] fn test_handler_enabled() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(true, mw).new_transform(srv.into_service())) + let mut mw = Condition::new(true, mw) + .new_transform(srv.into_service()) + .await .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } + #[test] fn test_handler_disabled() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(false, mw).new_transform(srv.into_service())) + let mut mw = Condition::new(false, mw) + .new_transform(srv.into_service()) + .await .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE), None); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + }) } } diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 5f73d4d7..c8a70285 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -1,9 +1,9 @@ //! Custom handlers service for responses. use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures::future::{err, ok, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{err, ok, Either, Future, FutureExt, LocalBoxFuture, Ready}; use hashbrown::HashMap; use crate::dev::{ServiceRequest, ServiceResponse}; @@ -15,7 +15,7 @@ pub enum ErrorHandlerResponse { /// New http response got generated Response(ServiceResponse), /// Result is a future that resolves to a new http response - Future(Box, Error = Error>>), + Future(LocalBoxFuture<'static, Result, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; @@ -39,17 +39,17 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result { handlers: Rc>>>, @@ -92,7 +92,7 @@ where type Error = Error; type InitError = (); type Transform = ErrorHandlersMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(ErrorHandlersMiddleware { @@ -117,26 +117,30 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); + let fut = self.service.call(req); + + async move { + let res = fut.await?; - Box::new(self.service.call(req).and_then(move |res| { if let Some(handler) = handlers.get(&res.status()) { match handler(res) { - Ok(ErrorHandlerResponse::Response(res)) => Either::A(ok(res)), - Ok(ErrorHandlerResponse::Future(fut)) => Either::B(fut), - Err(e) => Either::A(err(e)), + Ok(ErrorHandlerResponse::Response(res)) => Ok(res), + Ok(ErrorHandlerResponse::Future(fut)) => fut.await, + Err(e) => Err(e), } } else { - Either::A(ok(res)) + Ok(res) } - })) + } + .boxed_local() } } @@ -147,7 +151,7 @@ mod tests { use super::*; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{self, TestRequest}; + use crate::test::{self, block_on, TestRequest}; use crate::HttpResponse; fn render_500(mut res: ServiceResponse) -> Result> { @@ -159,19 +163,22 @@ mod tests { #[test] fn test_handler() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mut mw = test::block_on( - ErrorHandlers::new() + let mut mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) - .new_transform(srv.into_service()), - ) - .unwrap(); + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } fn render_500_async( @@ -180,23 +187,26 @@ mod tests { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Future(Box::new(ok(res)))) + Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) } #[test] fn test_handler_async() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mut mw = test::block_on( - ErrorHandlers::new() + let mut mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) - .new_transform(srv.into_service()), - ) - .unwrap(); + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 30acad15..84e0758b 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -2,13 +2,13 @@ mod compress; pub use self::compress::{BodyEncoding, Compress}; -//mod condition; +mod condition; mod defaultheaders; -//pub mod errhandlers; +pub mod errhandlers; mod logger; -//mod normalize; +mod normalize; -//pub use self::condition::Condition; +pub use self::condition::Condition; pub use self::defaultheaders::DefaultHeaders; pub use self::logger::Logger; -//pub use self::normalize::NormalizePath; +pub use self::normalize::NormalizePath; diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 9cfbefb3..b7eb1384 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,9 +1,10 @@ //! `Middleware` to normalize request's URI +use std::task::{Context, Poll}; use actix_http::http::{HttpTryFrom, PathAndQuery, Uri}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures::future::{self, FutureResult}; +use futures::future::{ok, Ready}; use regex::Regex; use crate::service::{ServiceRequest, ServiceResponse}; @@ -19,15 +20,15 @@ use crate::Error; /// ```rust /// use actix_web::{web, http, middleware, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new() -/// .wrap(middleware::NormalizePath) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # fn main() { +/// let app = App::new() +/// .wrap(middleware::NormalizePath) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// # } /// ``` pub struct NormalizePath; @@ -42,10 +43,10 @@ where type Error = Error; type InitError = (); type Transform = NormalizePathNormalization; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - future::ok(NormalizePathNormalization { + ok(NormalizePathNormalization { service, merge_slash: Regex::new("//+").unwrap(), }) @@ -67,8 +68,8 @@ where type Error = Error; type Future = S::Future; - fn poll_ready(&mut self) -> futures::Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -109,46 +110,57 @@ mod tests { #[test] fn test_wrap() { - let mut app = init_service( - App::new() - .wrap(NormalizePath::default()) - .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), - ); + block_on(async { + let mut app = init_service( + App::new() + .wrap(NormalizePath::default()) + .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), + ) + .await; - let req = TestRequest::with_uri("/v1//something////").to_request(); - let res = call_service(&mut app, req); - assert!(res.status().is_success()); + let req = TestRequest::with_uri("/v1//something////").to_request(); + let res = call_service(&mut app, req).await; + assert!(res.status().is_success()); + }) } #[test] fn test_in_place_normalization() { - let srv = |req: ServiceRequest| { - assert_eq!("/v1/something/", req.path()); - req.into_response(HttpResponse::Ok().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + assert_eq!("/v1/something/", req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); - let req = TestRequest::with_uri("/v1//something////").to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); - assert!(res.status().is_success()); + let req = TestRequest::with_uri("/v1//something////").to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + }) } #[test] fn should_normalize_nothing() { - const URI: &str = "/v1/something/"; + block_on(async { + const URI: &str = "/v1/something/"; - let srv = |req: ServiceRequest| { - assert_eq!(URI, req.path()); - req.into_response(HttpResponse::Ok().finish()) - }; + let srv = |req: ServiceRequest| { + assert_eq!(URI, req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); - let req = TestRequest::with_uri(URI).to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); - assert!(res.status().is_success()); + let req = TestRequest::with_uri(URI).to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + }) } }