From 5b06f2bee5663ed51cfb964e25360f7d10390515 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 23 Mar 2019 21:29:16 -0700 Subject: [PATCH] port cors middleware --- README.md | 23 +- src/middleware/cors.rs | 1173 ++++++++++++++---------------- src/middleware/defaultheaders.rs | 7 +- src/middleware/mod.rs | 1 + src/test.rs | 27 +- 5 files changed, 582 insertions(+), 649 deletions(-) diff --git a/README.md b/README.md index c7e195de9..ce9efbb71 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust. -* Supported *HTTP/1.x* and [*HTTP/2.0*](https://actix.rs/docs/http2/) protocols +* Supported *HTTP/1.x* and *HTTP/2.0* protocols * Streaming and pipelining * Keep-alive and slow requests handling * Client/server [WebSockets](https://actix.rs/docs/websockets/) support @@ -13,33 +13,33 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust. * SSL support with OpenSSL or `native-tls` * Middlewares ([Logger, Session, CORS, CSRF, etc](https://actix.rs/docs/middleware/)) * Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) -* Built on top of [Actix actor framework](https://github.com/actix/actix) +* Supports [Actix actor framework](https://github.com/actix/actix) * Experimental [Async/Await](https://github.com/mehcode/actix-web-async-await) support. ## Documentation & community resources * [User Guide](https://actix.rs/docs/) * [API Documentation (Development)](https://actix.rs/actix-web/actix_web/) -* [API Documentation (Releases)](https://actix.rs/api/actix-web/stable/actix_web/) +* [API Documentation (Releases)](https://docs.rs/actix-web/) * [Chat on gitter](https://gitter.im/actix/actix) * Cargo package: [actix-web](https://crates.io/crates/actix-web) -* Minimum supported Rust version: 1.31 or later +* Minimum supported Rust version: 1.32 or later ## Example ```rust -extern crate actix_web; -use actix_web::{http, server, App, Path, Responder}; +use actix_web::{web, App, HttpServer, Responder}; -fn index(info: Path<(u32, String)>) -> impl Responder { +fn index(info: web::Path<(u32, String)>) -> impl Responder { format!("Hello {}! id:{}", info.1, info.0) } -fn main() { - server::new( +fn main() -> std::io::Result<()> { + HttpServer::new( || App::new() - .route("/{id}/{name}/index.html", http::Method::GET, index)) - .bind("127.0.0.1:8080").unwrap() + .service(web::resource("/{id}/{name}/index.html") + .route(web::get().to(index))) + .bind("127.0.0.1:8080")? .run(); } ``` @@ -48,7 +48,6 @@ fn main() { * [Basics](https://github.com/actix/examples/tree/master/basics/) * [Stateful](https://github.com/actix/examples/tree/master/state/) -* [Protobuf support](https://github.com/actix/examples/tree/master/protobuf/) * [Multipart streams](https://github.com/actix/examples/tree/master/multipart/) * [Simple websocket](https://github.com/actix/examples/tree/master/websocket/) * [Tera](https://github.com/actix/examples/tree/master/template_tera/) / diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 80ee5b193..8f33d69b0 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -1,45 +1,36 @@ //! Cross-origin resource sharing (CORS) for Actix applications //! //! CORS middleware could be used with application and with resource. -//! First you need to construct CORS middleware instance. -//! -//! To construct a cors: -//! -//! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. -//! 2. Use any of the builder methods to set fields in the backend. -//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the -//! constructed backend. -//! -//! Cors middleware could be used as parameter for `App::middleware()` or -//! `Resource::middleware()` methods. But you have to use -//! `Cors::for_app()` method to support *preflight* OPTIONS request. -//! +//! Cors middleware could be used as parameter for `App::middleware()`, +//! `Resource::middleware()` or `Scope::middleware()` methods. //! //! # Example //! //! ```rust -//! # extern crate actix_web; //! use actix_web::middleware::cors::Cors; -//! use actix_web::{http, App, HttpRequest, HttpResponse}; +//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! -//! fn index(mut req: HttpRequest) -> &'static str { +//! fn index(req: HttpRequest) -> &'static str { //! "Hello world" //! } //! -//! fn main() { -//! let app = App::new().configure(|app| { -//! Cors::for_app(app) // <- Construct CORS middleware builder -//! .allowed_origin("https://www.rust-lang.org/") -//! .allowed_methods(vec!["GET", "POST"]) -//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) -//! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600) -//! .resource("/index.html", |r| { -//! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); -//! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); -//! }) -//! .register() -//! }); +//! fn main() -> std::io::Result<()> { +//! HttpServer::new(|| App::new() +//! .middleware( +//! Cors::new() // <- Construct CORS middleware builder +//! .allowed_origin("https://www.rust-lang.org/") +//! .allowed_methods(vec!["GET", "POST"]) +//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) +//! .allowed_header(http::header::CONTENT_TYPE) +//! .max_age(3600)) +//! .service( +//! web::resource("/index.html") +//! .route(web::get().to(index)) +//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +//! )) +//! .bind("127.0.0.1:8080")?; +//! +//! Ok(()) //! } //! ``` //! In this example custom *CORS* middleware get registered for "/index.html" @@ -50,68 +41,67 @@ use std::collections::HashSet; use std::iter::FromIterator; use std::rc::Rc; -use http::header::{self, HeaderName, HeaderValue}; -use http::{self, HttpTryFrom, Method, StatusCode, Uri}; +use actix_service::{IntoTransform, Service, Transform}; +use derive_more::Display; +use futures::future::{ok, Either, Future, FutureResult}; +use futures::Poll; -use application::App; -use error::{ResponseError, Result}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Response, Started}; -use resource::Resource; -use router::ResourceDef; -use server::Request; +use crate::dev::{Head, RequestHead}; +use crate::error::{ResponseError, Result}; +use crate::http::header::{self, HeaderName, HeaderValue}; +use crate::http::{self, HttpTryFrom, Method, StatusCode, Uri}; +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{HttpMessage, HttpResponse}; /// A set of errors that can occur during processing CORS -#[derive(Debug, Fail)] +#[derive(Debug, Display)] pub enum CorsError { /// The HTTP request header `Origin` is required but was not provided - #[fail( - display = "The HTTP request header `Origin` is required but was not provided" + #[display( + fmt = "The HTTP request header `Origin` is required but was not provided" )] MissingOrigin, /// The HTTP request header `Origin` could not be parsed correctly. - #[fail(display = "The HTTP request header `Origin` could not be parsed correctly.")] + #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")] BadOrigin, /// The request header `Access-Control-Request-Method` is required but is /// missing - #[fail( - display = "The request header `Access-Control-Request-Method` is required but is missing" + #[display( + fmt = "The request header `Access-Control-Request-Method` is required but is missing" )] MissingRequestMethod, /// The request header `Access-Control-Request-Method` has an invalid value - #[fail( - display = "The request header `Access-Control-Request-Method` has an invalid value" + #[display( + fmt = "The request header `Access-Control-Request-Method` has an invalid value" )] BadRequestMethod, /// The request header `Access-Control-Request-Headers` has an invalid /// value - #[fail( - display = "The request header `Access-Control-Request-Headers` has an invalid value" + #[display( + fmt = "The request header `Access-Control-Request-Headers` has an invalid value" )] BadRequestHeaders, /// The request header `Access-Control-Request-Headers` is required but is /// missing. - #[fail( - display = "The request header `Access-Control-Request-Headers` is required but is + #[display( + fmt = "The request header `Access-Control-Request-Headers` is required but is missing" )] MissingRequestHeaders, /// Origin is not allowed to make this request - #[fail(display = "Origin is not allowed to make this request")] + #[display(fmt = "Origin is not allowed to make this request")] OriginNotAllowed, /// Requested method is not allowed - #[fail(display = "Requested method is not allowed")] + #[display(fmt = "Requested method is not allowed")] MethodNotAllowed, /// One or more headers requested are not allowed - #[fail(display = "One or more headers requested are not allowed")] + #[display(fmt = "One or more headers requested are not allowed")] HeadersNotAllowed, } impl ResponseError for CorsError { fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self)) + HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) } } @@ -156,327 +146,6 @@ impl AllOrSome { } } -/// `Middleware` for Cross-origin resource sharing support -/// -/// The Cors struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -#[derive(Clone)] -pub struct Cors { - inner: Rc, -} - -struct Inner { - methods: HashSet, - origins: AllOrSome>, - origins_str: Option, - headers: AllOrSome>, - expose_hdrs: Option, - max_age: Option, - preflight: bool, - send_wildcard: bool, - supports_credentials: bool, - vary_header: bool, -} - -impl Default for Cors { - fn default() -> Cors { - let inner = Inner { - origins: AllOrSome::default(), - origins_str: None, - methods: HashSet::from_iter( - vec![ - Method::GET, - Method::HEAD, - Method::POST, - Method::OPTIONS, - Method::PUT, - Method::PATCH, - Method::DELETE, - ].into_iter(), - ), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }; - Cors { - inner: Rc::new(inner), - } - } -} - -impl Cors { - /// Build a new CORS middleware instance - pub fn build() -> CorsBuilder<()> { - CorsBuilder { - cors: Some(Inner { - origins: AllOrSome::All, - origins_str: None, - methods: HashSet::new(), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }), - methods: false, - error: None, - expose_hdrs: HashSet::new(), - resources: Vec::new(), - app: None, - } - } - - /// Create CorsBuilder for a specified application. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::middleware::cors::Cors; - /// use actix_web::{http, App, HttpResponse}; - /// - /// fn main() { - /// let app = App::new().configure( - /// |app| { - /// Cors::for_app(app) // <- Construct CORS builder - /// .allowed_origin("https://www.rust-lang.org/") - /// .resource("/resource", |r| { // register resource - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// }) - /// .register() - /// }, // construct CORS and return application instance - /// ); - /// } - /// ``` - pub fn for_app(app: App) -> CorsBuilder { - CorsBuilder { - cors: Some(Inner { - origins: AllOrSome::All, - origins_str: None, - methods: HashSet::new(), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }), - methods: false, - error: None, - expose_hdrs: HashSet::new(), - resources: Vec::new(), - app: Some(app), - } - } - - /// This method register cors middleware with resource and - /// adds route for *OPTIONS* preflight requests. - /// - /// It is possible to register *Cors* middleware with - /// `Resource::middleware()` method, but in that case *Cors* - /// middleware wont be able to handle *OPTIONS* requests. - pub fn register(self, resource: &mut Resource) { - resource - .method(Method::OPTIONS) - .h(|_: &_| HttpResponse::Ok()); - resource.middleware(self); - } - - fn validate_origin(&self, req: &Request) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ORIGIN) { - if let Ok(origin) = hdr.to_str() { - return match self.inner.origins { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => allowed_origins - .get(origin) - .and_then(|_| Some(())) - .ok_or_else(|| CorsError::OriginNotAllowed), - }; - } - Err(CorsError::BadOrigin) - } else { - return match self.inner.origins { - AllOrSome::All => Ok(()), - _ => Err(CorsError::MissingOrigin), - }; - } - } - - fn access_control_allow_origin(&self, req: &Request) -> Option { - match self.inner.origins { - AllOrSome::All => { - if self.inner.send_wildcard { - Some(HeaderValue::from_static("*")) - } else if let Some(origin) = req.headers().get(header::ORIGIN) { - Some(origin.clone()) - } else { - None - } - } - AllOrSome::Some(ref origins) => { - if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| { - match o.to_str() { - Ok(os) => origins.contains(os), - _ => false - } - }) { - Some(origin.clone()) - } else { - Some(self.inner.origins_str.as_ref().unwrap().clone()) - } - } - } - } - - fn validate_allowed_method(&self, req: &Request) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { - if let Ok(meth) = hdr.to_str() { - if let Ok(method) = Method::try_from(meth) { - return self - .inner - .methods - .get(&method) - .and_then(|_| Some(())) - .ok_or_else(|| CorsError::MethodNotAllowed); - } - } - Err(CorsError::BadRequestMethod) - } else { - Err(CorsError::MissingRequestMethod) - } - } - - fn validate_allowed_headers(&self, req: &Request) -> Result<(), CorsError> { - match self.inner.headers { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_headers) => { - if let Some(hdr) = - req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) - { - if let Ok(headers) = hdr.to_str() { - let mut hdrs = HashSet::new(); - for hdr in headers.split(',') { - match HeaderName::try_from(hdr.trim()) { - Ok(hdr) => hdrs.insert(hdr), - Err(_) => return Err(CorsError::BadRequestHeaders), - }; - } - - if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { - return Err(CorsError::HeadersNotAllowed); - } - return Ok(()); - } - Err(CorsError::BadRequestHeaders) - } else { - Err(CorsError::MissingRequestHeaders) - } - } - } - } -} - -impl Middleware for Cors { - fn start(&self, req: &HttpRequest) -> Result { - if self.inner.preflight && Method::OPTIONS == *req.method() { - self.validate_origin(req)?; - self.validate_allowed_method(&req)?; - self.validate_allowed_headers(&req)?; - - // allowed headers - let headers = if let Some(headers) = self.inner.headers.as_ref() { - Some( - HeaderValue::try_from( - &headers - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ).unwrap(), - ) - } else if let Some(hdr) = - req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) - { - Some(hdr.clone()) - } else { - None - }; - - Ok(Started::Response( - HttpResponse::Ok() - .if_some(self.inner.max_age.as_ref(), |max_age, resp| { - let _ = resp.header( - header::ACCESS_CONTROL_MAX_AGE, - format!("{}", max_age).as_str(), - ); - }).if_some(headers, |headers, resp| { - let _ = - resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }).if_some(self.access_control_allow_origin(&req), |origin, resp| { - let _ = - resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }).if_true(self.inner.supports_credentials, |resp| { - resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }).header( - header::ACCESS_CONTROL_ALLOW_METHODS, - &self - .inner - .methods - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ).finish(), - )) - } else { - // Only check requests with a origin header. - if req.headers().contains_key(header::ORIGIN) { - self.validate_origin(req)?; - } - - Ok(Started::Done) - } - } - - fn response( - &self, req: &HttpRequest, mut resp: HttpResponse, - ) -> Result { - - if let Some(origin) = self.access_control_allow_origin(req) { - resp.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - }; - - if let Some(ref expose) = self.inner.expose_hdrs { - resp.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap(), - ); - } - if self.inner.supports_credentials { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if self.inner.vary_header { - let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { - let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); - val.extend(hdr.as_bytes()); - val.extend(b", Origin"); - HeaderValue::try_from(&val[..]).unwrap() - } else { - HeaderValue::from_static("Origin") - }; - resp.headers_mut().insert(header::VARY, value); - } - Ok(Response::Done(resp)) - } -} - /// Structure that follows the builder pattern for building `Cors` middleware /// structs. /// @@ -490,40 +159,77 @@ impl Middleware for Cors { /// # Example /// /// ```rust -/// # extern crate http; -/// # extern crate actix_web; +/// use actix_web::http::header; /// use actix_web::middleware::cors; -/// use http::header; /// /// # fn main() { -/// let cors = cors::Cors::build() +/// let cors = cors::Cors::new() /// .allowed_origin("https://www.rust-lang.org/") /// .allowed_methods(vec!["GET", "POST"]) /// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) /// .allowed_header(header::CONTENT_TYPE) -/// .max_age(3600) -/// .finish(); +/// .max_age(3600); /// # } /// ``` -pub struct CorsBuilder { +pub struct Cors { cors: Option, methods: bool, error: Option, expose_hdrs: HashSet, - resources: Vec>, - app: Option>, } -fn cors<'a>( - parts: &'a mut Option, err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; +impl Cors { + /// Build a new CORS middleware instance + pub fn new() -> Cors { + Cors { + cors: Some(Inner { + origins: AllOrSome::All, + origins_str: None, + methods: HashSet::new(), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }), + methods: false, + error: None, + expose_hdrs: HashSet::new(), + } + } + + /// Build a new CORS default middleware + pub fn default() -> CorsFactory { + let inner = Inner { + origins: AllOrSome::default(), + origins_str: None, + methods: HashSet::from_iter( + vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ] + .into_iter(), + ), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }; + CorsFactory { + inner: Rc::new(inner), + } } - parts.as_mut() -} -impl CorsBuilder { /// Add an origin that are allowed to make requests. /// Will be verified against the `Origin` request header. /// @@ -541,7 +247,7 @@ impl CorsBuilder { /// Defaults to `All`. /// /// Builder panics if supplied origin is not valid uri. - pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder { + pub fn allowed_origin(mut self, origin: &str) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { match Uri::try_from(origin) { Ok(_) => { @@ -567,7 +273,7 @@ impl CorsBuilder { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` - pub fn allowed_methods(&mut self, methods: U) -> &mut CorsBuilder + pub fn allowed_methods(mut self, methods: U) -> Cors where U: IntoIterator, Method: HttpTryFrom, @@ -590,7 +296,7 @@ impl CorsBuilder { } /// Set an allowed header - pub fn allowed_header(&mut self, header: H) -> &mut CorsBuilder + pub fn allowed_header(mut self, header: H) -> Cors where HeaderName: HttpTryFrom, { @@ -621,7 +327,7 @@ impl CorsBuilder { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `All`. - pub fn allowed_headers(&mut self, headers: U) -> &mut CorsBuilder + pub fn allowed_headers(mut self, headers: U) -> Cors where U: IntoIterator, HeaderName: HttpTryFrom, @@ -655,7 +361,7 @@ impl CorsBuilder { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// This defaults to an empty set. - pub fn expose_headers(&mut self, headers: U) -> &mut CorsBuilder + pub fn expose_headers(mut self, headers: U) -> Cors where U: IntoIterator, HeaderName: HttpTryFrom, @@ -678,7 +384,7 @@ impl CorsBuilder { /// This value is set as the `Access-Control-Max-Age` header. /// /// This defaults to `None` (unset). - pub fn max_age(&mut self, max_age: usize) -> &mut CorsBuilder { + pub fn max_age(mut self, max_age: usize) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.max_age = Some(max_age) } @@ -700,7 +406,7 @@ impl CorsBuilder { /// CredentialsWithWildcardOrigin` error during actix launch or runtime. /// /// Defaults to `false`. - pub fn send_wildcard(&mut self) -> &mut CorsBuilder { + pub fn send_wildcard(mut self) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.send_wildcard = true } @@ -720,7 +426,7 @@ impl CorsBuilder { /// /// Builder panics if credentials are allowed, but the Origin is set to "*". /// This is not allowed by W3C - pub fn supports_credentials(&mut self) -> &mut CorsBuilder { + pub fn supports_credentials(mut self) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.supports_credentials = true } @@ -738,7 +444,7 @@ impl CorsBuilder { /// caches that the CORS headers are dynamic, and cannot be cached. /// /// By default `vary` header support is enabled. - pub fn disable_vary_header(&mut self) -> &mut CorsBuilder { + pub fn disable_vary_header(mut self) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.vary_header = false } @@ -751,57 +457,32 @@ impl CorsBuilder { /// This is useful application level middleware. /// /// By default *preflight* support is enabled. - pub fn disable_preflight(&mut self) -> &mut CorsBuilder { + pub fn disable_preflight(mut self) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.preflight = false } self } +} - /// Configure resource for a specific path. - /// - /// This is similar to a `App::resource()` method. Except, cors middleware - /// get registered for the resource. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::middleware::cors::Cors; - /// use actix_web::{http, App, HttpResponse}; - /// - /// fn main() { - /// let app = App::new().configure( - /// |app| { - /// Cors::for_app(app) // <- Construct CORS builder - /// .allowed_origin("https://www.rust-lang.org/") - /// .allowed_methods(vec!["GET", "POST"]) - /// .allowed_header(http::header::CONTENT_TYPE) - /// .max_age(3600) - /// .resource("/resource1", |r| { // register resource - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// }) - /// .resource("/resource2", |r| { // register another resource - /// r.method(http::Method::HEAD) - /// .f(|_| HttpResponse::MethodNotAllowed()); - /// }) - /// .register() - /// }, // construct CORS and return application instance - /// ); - /// } - /// ``` - pub fn resource(&mut self, path: &str, f: F) -> &mut CorsBuilder - where - F: FnOnce(&mut Resource) -> R + 'static, - { - // add resource handler - let mut resource = Resource::new(ResourceDef::new(path)); - f(&mut resource); - - self.resources.push(resource); - self +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; } + parts.as_mut() +} - fn construct(&mut self) -> Cors { - if !self.methods { +impl IntoTransform for Cors +where + S: Service, Response = ServiceResponse> + 'static, + P: 'static, + B: 'static, +{ + fn into_transform(self) -> CorsFactory { + let mut slf = if !self.methods { self.allowed_methods(vec![ Method::GET, Method::HEAD, @@ -810,14 +491,16 @@ impl CorsBuilder { Method::PUT, Method::PATCH, Method::DELETE, - ]); - } + ]) + } else { + self + }; - if let Some(e) = self.error.take() { + if let Some(e) = slf.error.take() { panic!("{}", e); } - let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); + let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder"); if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { panic!("Credentials are allowed, but the Origin is set to \"*\""); @@ -830,152 +513,383 @@ impl CorsBuilder { cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); } - if !self.expose_hdrs.is_empty() { + if !slf.expose_hdrs.is_empty() { cors.expose_hdrs = Some( - self.expose_hdrs + slf.expose_hdrs .iter() .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] .to_owned(), ); } - Cors { + + CorsFactory { inner: Rc::new(cors), } } +} - /// Finishes building and returns the built `Cors` instance. - /// - /// This method panics in case of any configuration error. - pub fn finish(&mut self) -> Cors { - if !self.resources.is_empty() { - panic!( - "CorsBuilder::resource() was used, - to construct CORS `.register(app)` method should be used" - ); +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +pub struct CorsFactory { + inner: Rc, +} + +impl Transform for CorsFactory +where + S: Service, Response = ServiceResponse>, + S::Future: 'static, + S::Error: 'static, + P: 'static, + B: 'static, +{ + type Request = ServiceRequest

; + type Response = ServiceResponse; + type Error = S::Error; + type InitError = (); + type Transform = CorsMiddleware; + type Future = FutureResult; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CorsMiddleware { + service, + inner: self.inner.clone(), + }) + } +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +#[derive(Clone)] +pub struct CorsMiddleware { + service: S, + inner: Rc, +} + +struct Inner { + methods: HashSet, + origins: AllOrSome>, + origins_str: Option, + headers: AllOrSome>, + expose_hdrs: Option, + max_age: Option, + preflight: bool, + send_wildcard: bool, + supports_credentials: bool, + vary_header: bool, +} + +impl Inner { + fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(header::ORIGIN) { + if let Ok(origin) = hdr.to_str() { + return match self.origins { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::OriginNotAllowed), + }; + } + Err(CorsError::BadOrigin) + } else { + return match self.origins { + AllOrSome::All => Ok(()), + _ => Err(CorsError::MissingOrigin), + }; } - self.construct() } - /// Finishes building Cors middleware and register middleware for - /// application - /// - /// This method panics in case of any configuration error or if non of - /// resources are registered. - pub fn register(&mut self) -> App { - if self.resources.is_empty() { - panic!("No resources are registered."); + fn access_control_allow_origin(&self, req: &RequestHead) -> Option { + match self.origins { + AllOrSome::All => { + if self.send_wildcard { + Some(HeaderValue::from_static("*")) + } else if let Some(origin) = req.headers().get(header::ORIGIN) { + Some(origin.clone()) + } else { + None + } + } + AllOrSome::Some(ref origins) => { + if let Some(origin) = + req.headers() + .get(header::ORIGIN) + .filter(|o| match o.to_str() { + Ok(os) => origins.contains(os), + _ => false, + }) + { + Some(origin.clone()) + } else { + Some(self.origins_str.as_ref().unwrap().clone()) + } + } } + } - let cors = self.construct(); - let mut app = self - .app - .take() - .expect("CorsBuilder has to be constructed with Cors::for_app(app)"); - - // register resources - for mut resource in self.resources.drain(..) { - cors.clone().register(&mut resource); - app.register_resource(resource); + fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { + if let Ok(meth) = hdr.to_str() { + if let Ok(method) = Method::try_from(meth) { + return self + .methods + .get(&method) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::MethodNotAllowed); + } + } + Err(CorsError::BadRequestMethod) + } else { + Err(CorsError::MissingRequestMethod) } + } - app + fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> { + match self.headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if let Some(hdr) = + req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) + { + if let Ok(headers) = hdr.to_str() { + let mut hdrs = HashSet::new(); + for hdr in headers.split(',') { + match HeaderName::try_from(hdr.trim()) { + Ok(hdr) => hdrs.insert(hdr), + Err(_) => return Err(CorsError::BadRequestHeaders), + }; + } + + if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { + return Err(CorsError::HeadersNotAllowed); + } + return Ok(()); + } + Err(CorsError::BadRequestHeaders) + } else { + Err(CorsError::MissingRequestHeaders) + } + } + } + } +} + +impl Service for CorsMiddleware +where + S: Service, Response = ServiceResponse>, + S::Future: 'static, + S::Error: 'static, + P: 'static, + B: 'static, +{ + type Request = ServiceRequest

; + type Response = ServiceResponse; + type Error = S::Error; + type Future = Either< + FutureResult, + Either>>, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.service.poll_ready() + } + + fn call(&mut self, req: ServiceRequest

) -> Self::Future { + if self.inner.preflight && Method::OPTIONS == *req.method() { + if let Err(e) = self + .inner + .validate_origin(&req) + .and_then(|_| self.inner.validate_allowed_method(&req)) + .and_then(|_| self.inner.validate_allowed_headers(&req)) + { + return Either::A(ok(req.error_response(e))); + } + + // allowed headers + let headers = if let Some(headers) = self.inner.headers.as_ref() { + Some( + HeaderValue::try_from( + &headers + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .unwrap(), + ) + } else if let Some(hdr) = + req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) + { + Some(hdr.clone()) + } else { + None + }; + + let res = HttpResponse::Ok() + .if_some(self.inner.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( + header::ACCESS_CONTROL_MAX_AGE, + format!("{}", max_age).as_str(), + ); + }) + .if_some(headers, |headers, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); + }) + .if_some( + self.inner.access_control_allow_origin(&req), + |origin, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }, + ) + .if_true(self.inner.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + &self + .inner + .methods + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .finish() + .into_body(); + + Either::A(ok(req.into_response(res))) + } else if req.headers().contains_key(header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(&req) { + return Either::A(ok(req.error_response(e))); + } + + let inner = self.inner.clone(); + + Either::B(Either::B(Box::new(self.service.call(req).and_then( + move |mut res| { + if let Some(origin) = + inner.access_control_allow_origin(&res.request()) + { + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = + if let Some(hdr) = res.headers_mut().get(header::VARY) { + let mut val: Vec = + Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + HeaderValue::try_from(&val[..]).unwrap() + } else { + HeaderValue::from_static("Origin") + }; + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + }, + )))) + } else { + Either::B(Either::A(self.service.call(req))) + } } } #[cfg(test)] mod tests { - use super::*; - use test::{self, TestRequest}; + use actix_service::{FnService, Transform}; - impl Started { - fn is_done(&self) -> bool { - match *self { - Started::Done => true, - _ => false, - } - } - fn response(self) -> HttpResponse { - match self { - Started::Response(resp) => resp, - _ => panic!(), - } - } - } - impl Response { - fn response(self) -> HttpResponse { - match self { - Response::Done(resp) => resp, - _ => panic!(), - } + use super::*; + use crate::dev::PayloadStream; + use crate::test::{self, block_on, TestRequest}; + + impl Cors { + fn finish(self, srv: S) -> CorsMiddleware + where + S: Service, Response = ServiceResponse> + + 'static, + S::Future: 'static, + S::Error: 'static, + P: 'static, + B: 'static, + { + block_on( + IntoTransform::::into_transform(self).new_transform(srv), + ) + .unwrap() } } #[test] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] fn cors_validates_illegal_allow_credentials() { - Cors::build() + let _cors = Cors::new() .supports_credentials() .send_wildcard() - .finish(); - } - - #[test] - #[should_panic(expected = "No resources are registered")] - fn no_resource() { - Cors::build() - .supports_credentials() - .send_wildcard() - .register(); - } - - #[test] - #[should_panic(expected = "Cors::for_app(app)")] - fn no_resource2() { - Cors::build() - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) - .register(); + .finish(test::ok_service()); } #[test] fn validate_origin_allows_all_origins() { - let cors = Cors::default(); - let req = TestRequest::with_header("Origin", "https://www.example.com").finish(); + let mut cors = Cors::new().finish(test::ok_service()); + let req = + TestRequest::with_header("Origin", "https://www.example.com").to_service(); - assert!(cors.start(&req).ok().unwrap().is_done()) + let resp = test::call_success(&mut cors, req); + assert_eq!(resp.status(), StatusCode::OK); } #[test] fn test_preflight() { - let mut cors = Cors::build() + let mut cors = Cors::new() .send_wildcard() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE) - .finish(); + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) - .finish(); + .to_service(); - assert!(cors.start(&req).is_err()); + assert!(cors.inner.validate_allowed_method(&req).is_err()); + assert!(cors.inner.validate_allowed_headers(&req).is_err()); let req = TestRequest::with_header("Origin", "https://www.example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") .method(Method::OPTIONS) - .finish(); + .to_service(); - assert!(cors.start(&req).is_err()); + assert!(cors.inner.validate_allowed_method(&req).is_err()); + assert!(cors.inner.validate_allowed_headers(&req).is_err()); let req = TestRequest::with_header("Origin", "https://www.example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header( header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT", - ).method(Method::OPTIONS) - .finish(); + ) + .method(Method::OPTIONS) + .to_service(); - let resp = cors.start(&req).unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"*"[..], resp.headers() @@ -990,16 +904,39 @@ mod tests { .unwrap() .as_bytes() ); - //assert_eq!( - // &b"authorization,accept,content-type"[..], - // resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap(). - // as_bytes()); assert_eq!( - // &b"POST,GET,OPTIONS"[..], - // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap(). - // as_bytes()); + let hdr = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_HEADERS) + .unwrap() + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) + .unwrap() + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); Rc::get_mut(&mut cors.inner).unwrap().preflight = false; - assert!(cors.start(&req).unwrap().is_done()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_service(); + + let resp = test::call_success(&mut cors, req); + assert_eq!(resp.status(), StatusCode::OK); } // #[test] @@ -1015,46 +952,47 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn test_validate_not_allowed_origin() { - let cors = Cors::build() + let cors = Cors::new() .allowed_origin("https://www.example.com") - .finish(); + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://www.unknown.com") .method(Method::GET) - .finish(); - cors.start(&req).unwrap(); + .to_service(); + cors.inner.validate_origin(&req).unwrap(); + cors.inner.validate_allowed_method(&req).unwrap(); + cors.inner.validate_allowed_headers(&req).unwrap(); } #[test] fn test_validate_origin() { - let cors = Cors::build() + let mut cors = Cors::new() .allowed_origin("https://www.example.com") - .finish(); + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::GET) - .finish(); + .to_service(); - assert!(cors.start(&req).unwrap().is_done()); + let resp = test::call_success(&mut cors, req); + assert_eq!(resp.status(), StatusCode::OK); } #[test] fn test_no_origin_response() { - let cors = Cors::build().finish(); + let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); - let req = TestRequest::default().method(Method::GET).finish(); - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&req, resp).unwrap().response(); - assert!( - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .is_none() - ); + let req = TestRequest::default().method(Method::GET).to_service(); + let resp = test::call_success(&mut cors, req); + assert!(resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none()); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) - .finish(); - let resp = cors.response(&req, resp).unwrap().response(); + .to_service(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"https://www.example.com"[..], resp.headers() @@ -1067,7 +1005,7 @@ mod tests { #[test] fn test_response() { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let cors = Cors::build() + let mut cors = Cors::new() .send_wildcard() .disable_preflight() .max_age(3600) @@ -1075,14 +1013,13 @@ mod tests { .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) - .finish(); + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) - .finish(); + .to_service(); - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&req, resp).unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"*"[..], resp.headers() @@ -1111,21 +1048,40 @@ mod tests { } } - let resp: HttpResponse = - HttpResponse::Ok().header(header::VARY, "Accept").finish(); - let resp = cors.response(&req, resp).unwrap().response(); + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish(FnService::new(move |req: ServiceRequest| { + req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + ) + })); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_service(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"Accept, Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes() ); - let cors = Cors::build() + let mut cors = Cors::new() .disable_vary_header() .allowed_origin("https://www.example.com") .allowed_origin("https://www.google.com") - .finish(); - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&req, resp).unwrap().response(); + .finish(test::ok_service()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_service(); + let resp = test::call_success(&mut cors, req); let origins_str = resp .headers() @@ -1134,61 +1090,22 @@ mod tests { .to_str() .unwrap(); - assert_eq!( - "https://www.example.com", - origins_str - ); - } - - #[test] - fn cors_resource() { - let mut srv = test::TestServer::with_factory(|| { - App::new().configure(|app| { - Cors::for_app(app) - .allowed_origin("https://www.example.com") - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) - .register() - }) - }); - - let request = srv - .get() - .uri(srv.url("/test")) - .header("ORIGIN", "https://www.example2.com") - .finish() - .unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - - let request = srv.get().uri(srv.url("/test")).finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::OK); - - let request = srv - .get() - .uri(srv.url("/test")) - .header("ORIGIN", "https://www.example.com") - .finish() - .unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::OK); + assert_eq!("https://www.example.com", origins_str); } #[test] fn test_multiple_origins() { - let cors = Cors::build() + let mut cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish(); - + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://example.com") .method(Method::GET) - .finish(); - let resp: HttpResponse = HttpResponse::Ok().into(); + .to_service(); - let resp = cors.response(&req, resp).unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1199,10 +1116,9 @@ mod tests { let req = TestRequest::with_header("Origin", "https://example.org") .method(Method::GET) - .finish(); - let resp: HttpResponse = HttpResponse::Ok().into(); + .to_service(); - let resp = cors.response(&req, resp).unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"https://example.org"[..], resp.headers() @@ -1214,19 +1130,18 @@ mod tests { #[test] fn test_multiple_origins_preflight() { - let cors = Cors::build() + let mut cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish(); - + .finish(test::ok_service()); let req = TestRequest::with_header("Origin", "https://example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .method(Method::OPTIONS) - .finish(); + .to_service(); - let resp = cors.start(&req).ok().unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1238,9 +1153,9 @@ mod tests { let req = TestRequest::with_header("Origin", "https://example.org") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .method(Method::OPTIONS) - .finish(); + .to_service(); - let resp = cors.start(&req).ok().unwrap().response(); + let resp = test::call_success(&mut cors, req); assert_eq!( &b"https://example.org"[..], resp.headers() diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index bca2cf6e0..4d879cda7 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -154,18 +154,15 @@ mod tests { use super::*; use crate::dev::ServiceRequest; use crate::http::header::CONTENT_TYPE; - use crate::test::{block_on, TestRequest}; + use crate::test::{block_on, ok_service, TestRequest}; use crate::HttpResponse; #[test] fn test_default_headers() { - let srv = FnService::new(|req: ServiceRequest<_>| { - req.into_response(HttpResponse::Ok().finish()) - }); let mut mw = block_on( DefaultHeaders::new() .header(CONTENT_TYPE, "0001") - .new_transform(srv), + .new_transform(ok_service()), ) .unwrap(); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 9b13a20ab..b997cca28 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -3,6 +3,7 @@ mod compress; #[cfg(any(feature = "brotli", feature = "flate2"))] pub use self::compress::Compress; +pub mod cors; mod defaultheaders; mod errhandlers; mod logger; diff --git a/src/test.rs b/src/test.rs index ed9cf27cc..7e79659af 100644 --- a/src/test.rs +++ b/src/test.rs @@ -3,13 +3,13 @@ use std::cell::RefCell; use std::rc::Rc; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; -use actix_http::http::{HttpTryFrom, Method, Version}; +use actix_http::http::{HttpTryFrom, Method, StatusCode, Version}; use actix_http::test::TestRequest as HttpTestRequest; use actix_http::{Extensions, PayloadStream, Request}; use actix_router::{Path, ResourceDef, Url}; use actix_rt::Runtime; use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{FnService, IntoNewService, NewService, Service}; use bytes::Bytes; #[cfg(feature = "cookies")] use cookie::Cookie; @@ -17,9 +17,10 @@ use futures::future::{lazy, Future}; use crate::config::{AppConfig, AppConfigInner}; use crate::data::RouteData; +use crate::dev::Body; use crate::rmap::ResourceMap; use crate::service::{ServiceFromRequest, ServiceRequest, ServiceResponse}; -use crate::{HttpRequest, HttpResponse}; +use crate::{Error, HttpRequest, HttpResponse}; thread_local! { static RT: RefCell = { @@ -55,6 +56,26 @@ where RT.with(move |rt| rt.borrow_mut().block_on(lazy(f))) } +pub fn ok_service() -> impl Service< + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, +> { + default_service(StatusCode::OK) +} + +pub fn default_service( + status_code: StatusCode, +) -> impl Service< + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, +> { + FnService::new(move |req: ServiceRequest| { + req.into_response(HttpResponse::build(status_code).finish()) + }) +} + /// This method accepts application builder instance, and constructs /// service. ///