From 6084c47810fd3e2aeac564b1b47d3a7fedc9311e Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 19 Oct 2020 05:51:31 +0100 Subject: [PATCH] CORS builder rework (#119) --- actix-cors/CHANGES.md | 10 +- actix-cors/Cargo.toml | 15 +- actix-cors/examples/cors.rs | 43 ++++ actix-cors/src/all_or_some.rs | 13 +- actix-cors/src/builder.rs | 462 +++++++++++++++++++++------------- actix-cors/src/error.rs | 9 +- actix-cors/src/inner.rs | 278 +++++++++++--------- actix-cors/src/lib.rs | 18 +- actix-cors/src/middleware.rs | 263 +++++++++++-------- actix-cors/tests/tests.rs | 107 ++++---- 10 files changed, 730 insertions(+), 488 deletions(-) create mode 100644 actix-cors/examples/cors.rs diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index 90646a1e3..07d4863ef 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -2,10 +2,18 @@ ## Unreleased - 2020-xx-xx * Disallow `*` in `Cors::allowed_origin` by panicking. [#114]. -* Hide `CorsMiddleware` from rustdocs. [#118]. +* Hide `CorsMiddleware` from docs. [#118]. +* `CorsFactory` is removed. [#119] +* The `impl Default` constructor is now overly-restrictive. [#119] +* Added `Cors::permissive()` constructor that allows anything. [#119] +* Adds methods for each property to reset to a permissive state. (`allow_any_origin`, `expose_any_header`, etc.) [#119] +* Errors are now propagated with `Transform::InitError` instead of panicking. [#119] +* Fixes bug where allowed origin functions are not called if `allowed_origins` is All. [#119] +* `AllOrSome` is no longer public. [#119] [#114]: https://github.com/actix/actix-extras/pull/114 [#118]: https://github.com/actix/actix-extras/pull/118 +[#119]: https://github.com/actix/actix-extras/pull/119 ## 0.4.1 - 2020-10-07 diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 9c73b0efe..d49a83de9 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "actix-cors" version = "0.4.1" -authors = ["Nikolay Kim "] +authors = [ + "Nikolay Kim ", + "Rob Ede " +] description = "Cross-Origin Resource Sharing (CORS) controls for Actix Web" readme = "README.md" keywords = ["actix", "cors", "web", "security", "crossorigin"] @@ -19,8 +22,12 @@ path = "src/lib.rs" actix-web = { version = "3.0.0", default-features = false } derive_more = "0.99.2" futures-util = { version = "0.3.4", default-features = false } +log = "0.4" +once_cell = "1" +tinyvec = "1.0.0" [dev-dependencies] -actix-service = "1.0.6" -actix-rt = "1.1.1" -regex = "1.3.9" +actix-service = "1" +actix-rt = "1" +pretty_env_logger = "0.4" +regex = "1.4" diff --git a/actix-cors/examples/cors.rs b/actix-cors/examples/cors.rs new file mode 100644 index 000000000..ca534b550 --- /dev/null +++ b/actix-cors/examples/cors.rs @@ -0,0 +1,43 @@ +use actix_cors::Cors; +use actix_web::{http::header, web, App, HttpServer}; + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + pretty_env_logger::init(); + + HttpServer::new(move || { + App::new() + .wrap( + // default settings are overly restrictive to reduce chance of + // misconfiguration leading to security concerns + Cors::default() + // add specific origin to allowed origin list + .allowed_origin("http://project.local:8080") + // allow any port on localhost + .allowed_origin_fn(|req_head| { + // unwrapping is acceptable on the origin header since this function is + // only called when it exists + req_head + .headers() + .get(header::ORIGIN) + .unwrap() + .as_bytes() + .starts_with(b"http://localhost") + }) + // set allowed methods list + .allowed_methods(vec!["GET", "POST"]) + // set allowed request header list + .allowed_headers(&[header::AUTHORIZATION, header::ACCEPT]) + // add header to allowed list + .allowed_header(header::CONTENT_TYPE) + // set list of headers that are safe to expose + .expose_headers(&[header::CONTENT_DISPOSITION]) + // set CORS rules ttl + .max_age(3600), + ) + .default_service(web::to(|| async { "Hello world!" })) + }) + .bind("127.0.0.1:8080")? + .run() + .await +} diff --git a/actix-cors/src/all_or_some.rs b/actix-cors/src/all_or_some.rs index 065bf29e5..d8cec53a1 100644 --- a/actix-cors/src/all_or_some.rs +++ b/actix-cors/src/all_or_some.rs @@ -1,5 +1,5 @@ /// An enum signifying that some of type `T` is allowed, or `All` (anything is allowed). -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum AllOrSome { /// Everything is allowed. Usually equivalent to the `*` value. All, @@ -22,17 +22,26 @@ impl AllOrSome { } /// Returns whether this is a `Some` variant. + #[allow(dead_code)] pub fn is_some(&self) -> bool { !self.is_all() } - /// Returns &T. + /// Provides a shared reference to `T` if variant is `Some`. pub fn as_ref(&self) -> Option<&T> { match *self { AllOrSome::All => None, AllOrSome::Some(ref t) => Some(t), } } + + /// Provides a mutable reference to `T` if variant is `Some`. + pub fn as_mut(&mut self) -> Option<&mut T> { + match *self { + AllOrSome::All => None, + AllOrSome::Some(ref mut t) => Some(t), + } + } } #[cfg(test)] diff --git a/actix-cors/src/builder.rs b/actix-cors/src/builder.rs index 4a08387ca..20780e20a 100644 --- a/actix-cors/src/builder.rs +++ b/actix-cors/src/builder.rs @@ -3,99 +3,117 @@ use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc}; use actix_web::{ dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform}, error::{Error, Result}, - http::{self, header::HeaderName, Error as HttpError, Method, Uri}, + http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri}, + Either, }; -use futures_util::future::{ok, Ready}; +use futures_util::future::{self, Ready}; +use log::error; +use once_cell::sync::Lazy; +use tinyvec::tiny_vec; -use crate::{cors, AllOrSome, CorsMiddleware, Inner, OriginFn}; +use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn}; + +pub(crate) fn cors<'a>( + inner: &'a mut Rc, + err: &Option>, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + + Rc::get_mut(inner) +} + +static ALL_METHODS_SET: Lazy> = Lazy::new(|| { + HashSet::from_iter(vec![ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::HEAD, + Method::OPTIONS, + Method::CONNECT, + Method::PATCH, + Method::TRACE, + ]) +}); /// Builder for CORS middleware. /// -/// To construct a CORS middleware: +/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder. +/// Then use any of the builder methods to customize CORS behavior. /// -/// 1. Call [`Cors::new()`] to start building. -/// 2. Use any of the builder methods to customize CORS behavior. -/// 3. Call [`Cors::finish()`] to build the middleware. +/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing +/// all origins and headers, etc. **The permissive constructor should not be used in production.** +/// +/// # Errors +/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled +/// in Actix Web (using `env_logger` or other crate that exposes logs from the `log` crate), error +/// messages will outline what is wrong with the CORS configuration in the server logs and the +/// server will fail to start up or serve requests. /// /// # Example -/// /// ```rust -/// use actix_cors::{Cors, CorsFactory}; +/// use actix_cors::Cors; /// use actix_web::http::header; /// -/// let cors = Cors::new() +/// let cors = Cors::default() /// .allowed_origin("https://www.rust-lang.org") /// .allowed_methods(vec!["GET", "POST"]) /// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) /// .allowed_header(header::CONTENT_TYPE) -/// .max_age(3600) -/// .finish(); +/// .max_age(3600); /// /// // `cors` can now be used in `App::wrap`. /// ``` -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Cors { - cors: Option, - methods: bool, - error: Option, - expose_headers: HashSet, + inner: Rc, + error: Option>, } impl Cors { - /// Return a new builder. - pub fn new() -> Cors { - Cors { - cors: Some(Inner { - origins: AllOrSome::All, - origins_str: None, - origins_fns: Vec::new(), - methods: HashSet::new(), - headers: AllOrSome::All, - expose_headers: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }), - methods: false, - error: None, - expose_headers: HashSet::new(), - } - } - - /// Build a CORS middleware with default settings. - pub fn default() -> CorsFactory { + /// A very permissive set of default for quick development. Not recommended for production use. + /// + /// *All* origins, methods, request headers and exposed headers allowed. Credentials supported. + /// Max age 1 hour. Does not send wildcard. + pub fn permissive() -> Self { let inner = Inner { - origins: AllOrSome::default(), - origins_str: None, - origins_fns: Vec::new(), - methods: HashSet::from_iter( - vec![ - Method::GET, - Method::HEAD, - Method::POST, - Method::OPTIONS, - Method::PUT, - Method::PATCH, - Method::DELETE, - ] - .into_iter(), - ), - headers: AllOrSome::All, - expose_headers: None, - max_age: None, + allowed_origins: AllOrSome::All, + allowed_origins_fns: tiny_vec![], + + allowed_methods: ALL_METHODS_SET.clone(), + allowed_methods_baked: None, + + allowed_headers: AllOrSome::All, + allowed_headers_baked: None, + + expose_headers: AllOrSome::All, + expose_headers_baked: None, + max_age: Some(3600), preflight: true, send_wildcard: false, - supports_credentials: false, + supports_credentials: true, vary_header: true, }; - CorsFactory { + Cors { inner: Rc::new(inner), + error: None, } } + /// Resets allowed origin list to a state where any origin is accepted. + /// + /// See [`Cors::allowed_origin`] for more info on allowed origins. + pub fn allow_any_origin(mut self) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allowed_origins = AllOrSome::All; + } + + self + } + /// Add an origin that is allowed to make requests. /// /// By default, requests from all origins are accepted by CORS logic. This method allows to @@ -115,32 +133,34 @@ impl Cors { /// `allowed_origin_fn` function is set, these functions will be used to determinate /// allowed origins. /// - /// # Panics - /// - /// * If supplied origin is not valid uri, or - /// * If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead. + /// # Initialization Errors + /// - If supplied origin is not valid uri + /// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead. /// /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header pub fn allowed_origin(mut self, origin: &str) -> Cors { - assert!( - origin != "*", - "Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`." - ); - - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { match TryInto::::try_into(origin) { + Ok(_) if origin == "*" => { + error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`."); + self.error = Some(Either::B(CorsError::WildcardOrigin)); + } + Ok(_) => { - if cors.origins.is_all() { - cors.origins = AllOrSome::Some(HashSet::new()); + if cors.allowed_origins.is_all() { + cors.allowed_origins = + AllOrSome::Some(HashSet::with_capacity(8)); } - if let AllOrSome::Some(ref mut origins) = cors.origins { - origins.insert(origin.to_owned()); + if let Some(origins) = cors.allowed_origins.as_mut() { + // any uri is a valid header value + let hv = origin.try_into().unwrap(); + origins.insert(hv); } } - Err(e) => { - self.error = Some(e.into()); + Err(err) => { + self.error = Some(Either::A(err.into())); } } } @@ -160,15 +180,26 @@ impl Cors { where F: (Fn(&RequestHead) -> bool) + 'static, { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.origins_fns.push(OriginFn { - boxed_fn: Box::new(f), + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allowed_origins_fns.push(OriginFn { + boxed_fn: Rc::new(f), }); } self } + /// Resets allowed methods list to all methods. + /// + /// See [`Cors::allowed_methods`] for more info on allowed methods. + pub fn allow_any_method(mut self) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allowed_methods = ALL_METHODS_SET.clone(); + } + + self + } + /// Set a list of methods which allowed origins can perform. /// /// These will be sent in the `Access-Control-Allow-Methods` response header as specified in @@ -183,16 +214,15 @@ impl Cors { M: TryInto, >::Error: Into, { - self.methods = true; - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { for m in methods { match m.try_into() { Ok(method) => { - cors.methods.insert(method); + cors.allowed_methods.insert(method); } - Err(e) => { - self.error = Some(e.into()); + Err(err) => { + self.error = Some(Either::A(err.into())); break; } } @@ -202,34 +232,46 @@ impl Cors { self } - /// Add an allowed header. + /// Resets allowed request header list to a state where any origin is accepted. /// - /// See `Cors::allowed_headers()` for details. + /// See [`Cors::allowed_headers`] for more info on allowed request headers. + pub fn allow_any_header(mut self) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allowed_origins = AllOrSome::All; + } + + self + } + + /// Add an allowed request header. + /// + /// See [`Cors::allowed_headers`] for more info on allowed request headers. pub fn allowed_header(mut self, header: H) -> Cors where H: TryInto, >::Error: Into, { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { match header.try_into() { Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); + if cors.allowed_headers.is_all() { + cors.allowed_headers = + AllOrSome::Some(HashSet::with_capacity(8)); } - if let AllOrSome::Some(ref mut headers) = cors.headers { + if let AllOrSome::Some(ref mut headers) = cors.allowed_headers { headers.insert(method); } } - Err(e) => self.error = Some(e.into()), + Err(err) => self.error = Some(Either::A(err.into())), } } self } - /// Set a list of header field names which can be used when this resource is accessed by + /// Set a list of request header field names which can be used when this resource is accessed by /// allowed origins. /// /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers` @@ -245,19 +287,21 @@ impl Cors { H: TryInto, >::Error: Into, { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { for h in headers { match h.try_into() { Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); + if cors.allowed_headers.is_all() { + cors.allowed_headers = + AllOrSome::Some(HashSet::with_capacity(8)); } - if let AllOrSome::Some(ref mut headers) = cors.headers { + + if let AllOrSome::Some(ref mut headers) = cors.allowed_headers { headers.insert(method); } } - Err(e) => { - self.error = Some(e.into()); + Err(err) => { + self.error = Some(Either::A(err.into())); break; } } @@ -267,6 +311,17 @@ impl Cors { self } + /// Resets exposed response header list to a state where any header is accepted. + /// + /// See [`Cors::expose_headers`] for more info on exposed response headers. + pub fn expose_any_header(mut self) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allowed_origins = AllOrSome::All; + } + + self + } + /// Set a list of headers which are safe to expose to the API of a CORS API specification. /// This corresponds to the `Access-Control-Expose-Headers` response header as specified in /// the [Fetch Standard CORS protocol]. @@ -282,11 +337,19 @@ impl Cors { { for h in headers { match h.try_into() { - Ok(method) => { - self.expose_headers.insert(method); + Ok(header) => { + if let Some(cors) = cors(&mut self.inner, &self.error) { + if cors.expose_headers.is_all() { + cors.expose_headers = + AllOrSome::Some(HashSet::with_capacity(8)); + } + if let AllOrSome::Some(ref mut headers) = cors.expose_headers { + headers.insert(header); + } + } } - Err(e) => { - self.error = Some(e.into()); + Err(err) => { + self.error = Some(Either::A(err.into())); break; } } @@ -295,16 +358,16 @@ impl Cors { self } - /// Set a maximum time for which this CORS request maybe cached. + /// Set a maximum time (in seconds) for which this CORS request maybe cached. /// This value is set as the `Access-Control-Max-Age` header as specified in /// the [Fetch Standard CORS protocol]. /// - /// This defaults to `None` (unset). + /// Pass a number (of seconds) or use None to disable sending max age header. /// /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol - pub fn max_age(mut self, max_age: usize) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.max_age = Some(max_age) + pub fn max_age(mut self, max_age: impl Into>) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.max_age = max_age.into() } self @@ -322,7 +385,7 @@ impl Cors { /// /// Defaults to `false`. pub fn send_wildcard(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { cors.send_wildcard = true } @@ -340,12 +403,12 @@ impl Cors { /// /// Defaults to `false`. /// - /// Builder panics during `finish` if credentials are allowed, but the Origin is set to `*`. - /// This is not allowed by the CORS protocol. + /// A server initialization error will occur if credentials are allowed, but the Origin is set + /// to send wildcards (`*`); this is not allowed by the CORS protocol. /// /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol pub fn supports_credentials(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { cors.supports_credentials = true } @@ -363,7 +426,7 @@ impl Cors { /// /// By default, `Vary` header support is enabled. pub fn disable_vary_header(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { cors.vary_header = false } @@ -377,71 +440,48 @@ impl Cors { /// /// By default *preflight* support is enabled. pub fn disable_preflight(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { + if let Some(cors) = cors(&mut self.inner, &self.error) { cors.preflight = false } self } +} - /// Construct CORS middleware. - pub fn finish(self) -> CorsFactory { - let mut this = if !self.methods { - self.allowed_methods(vec![ - Method::GET, - Method::HEAD, - Method::POST, - Method::OPTIONS, - Method::PUT, - Method::PATCH, - Method::DELETE, - ]) - } else { - self +impl Default for Cors { + /// A restrictive (security paranoid) set of defaults. + /// + /// *No* allowed origins, methods, request headers or exposed headers. Credentials + /// not supported. No max age (will use browser's default). + fn default() -> Cors { + let inner = Inner { + allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)), + allowed_origins_fns: tiny_vec![], + + allowed_methods: HashSet::with_capacity(8), + allowed_methods_baked: None, + + allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)), + allowed_headers_baked: None, + + expose_headers: AllOrSome::Some(HashSet::with_capacity(8)), + expose_headers_baked: None, + + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, }; - if let Some(e) = this.error.take() { - panic!("{}", e); - } - - let mut cors = this.cors.take().expect("cannot reuse CorsBuilder"); - - if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { - panic!("Credentials are allowed, but the Origin is set to \"*\""); - } - - if let AllOrSome::Some(ref origins) = cors.origins { - let s = origins - .iter() - .fold(String::new(), |s, v| format!("{}, {}", s, v)); - cors.origins_str = Some(s[2..].try_into().unwrap()); - } - - if !this.expose_headers.is_empty() { - cors.expose_headers = Some( - this.expose_headers - .iter() - .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] - .to_owned(), - ); - } - - CorsFactory { - inner: Rc::new(cors), + Cors { + inner: Rc::new(inner), + error: None, } } } -/// Middleware for Cross-Origin Resource Sharing support. -/// -/// This struct contains the settings for CORS requests to be validated and for responses to -/// be generated. -#[derive(Debug)] -pub struct CorsFactory { - inner: Rc, -} - -impl Transform for CorsFactory +impl Transform for Cors where S: Service, Error = Error>, S::Future: 'static, @@ -455,13 +495,74 @@ where type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(CorsMiddleware { - service, - inner: Rc::clone(&self.inner), - }) + if let Some(ref err) = self.error { + match err { + Either::A(err) => error!("{}", err), + Either::B(err) => error!("{}", err), + } + + return future::err(()); + } + + let mut inner = Rc::clone(&self.inner); + + if inner.supports_credentials + && inner.send_wildcard + && inner.allowed_origins.is_all() + { + error!("Illegal combination of CORS options: credentials can not be supported when all \ + origins are allowed and `send_wildcard` is enabled."); + return future::err(()); + } + + // bake allowed headers value if Some and not empty + match inner.allowed_headers.as_ref() { + Some(header_set) if !header_set.is_empty() => { + let allowed_headers_str = intersperse_header_values(header_set); + Rc::make_mut(&mut inner).allowed_headers_baked = + Some(allowed_headers_str); + } + _ => {} + } + + // bake allowed methods value if not empty + if !inner.allowed_methods.is_empty() { + let allowed_methods_str = intersperse_header_values(&inner.allowed_methods); + Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str); + } + + // bake exposed headers value if Some and not empty + match inner.expose_headers.as_ref() { + Some(header_set) if !header_set.is_empty() => { + let expose_headers_str = intersperse_header_values(header_set); + Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str); + } + _ => {} + } + + future::ok(CorsMiddleware { service, inner }) } } +/// Only call when values are guaranteed to be valid header values and set is not empty. +fn intersperse_header_values(val_set: &HashSet) -> HeaderValue +where + T: AsRef, +{ + val_set + .iter() + .fold(String::with_capacity(32), |mut acc, val| { + acc.push_str(", "); + acc.push_str(val.as_ref()); + acc + }) + // set is not empty so string will always have leading ", " to trim + [2..] + .try_into() + // all method names are valid header values + .unwrap() +} + #[cfg(test)] mod test { use std::convert::{Infallible, TryInto}; @@ -475,13 +576,20 @@ mod test { use super::*; #[test] - #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] - fn cors_validates_illegal_allow_credentials() { - let _cors = Cors::new().supports_credentials().send_wildcard().finish(); + fn illegal_allow_credentials() { + // using the permissive defaults (all origins allowed) and adding send_wildcard + // and supports_credentials should error on construction + + assert!(Cors::permissive() + .supports_credentials() + .send_wildcard() + .new_transform(test::ok_service()) + .into_inner() + .is_err()); } #[actix_rt::test] - async fn default() { + async fn restrictive_defaults() { let mut cors = Cors::default() .new_transform(test::ok_service()) .await @@ -491,12 +599,12 @@ mod test { .to_srv_request(); let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } #[actix_rt::test] async fn allowed_header_try_from() { - let _cors = Cors::new().allowed_header("Content-Type"); + let _cors = Cors::default().allowed_header("Content-Type"); } #[actix_rt::test] @@ -511,6 +619,6 @@ mod test { } } - let _cors = Cors::new().allowed_header(ContentType); + let _cors = Cors::default().allowed_header(ContentType); } } diff --git a/actix-cors/src/error.rs b/actix-cors/src/error.rs index a986e08c7..09a316046 100644 --- a/actix-cors/src/error.rs +++ b/actix-cors/src/error.rs @@ -4,15 +4,16 @@ use derive_more::{Display, Error}; /// Errors that can occur when processing CORS guarded requests. #[derive(Debug, Clone, Display, Error)] +#[non_exhaustive] pub enum CorsError { + /// Allowed origin argument must not be wildcard (`*`). + #[display(fmt = "`allowed_origin` argument must not be wildcard (`*`).")] + WildcardOrigin, + /// Request header `Origin` is required but was not provided. #[display(fmt = "Request header `Origin` is required but was not provided.")] MissingOrigin, - /// Request header `Origin` could not be parsed correctly. - #[display(fmt = "Request header `Origin` could not be parsed correctly.")] - BadOrigin, - /// Request header `Access-Control-Request-Method` is required but is missing. #[display( fmt = "Request header `Access-Control-Request-Method` is required but is missing." diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs index 1b311ed22..cf031504f 100644 --- a/actix-cors/src/inner.rs +++ b/actix-cors/src/inner.rs @@ -1,30 +1,29 @@ -use std::{collections::HashSet, convert::TryInto, fmt}; +use std::{collections::HashSet, convert::TryFrom, convert::TryInto, fmt, rc::Rc}; use actix_web::{ dev::RequestHead, error::Result, http::{ - self, header::{self, HeaderName, HeaderValue}, Method, }, }; +use once_cell::sync::Lazy; +use tinyvec::TinyVec; use crate::{AllOrSome, CorsError}; -pub(crate) fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; - } - - parts.as_mut() +#[derive(Clone)] +pub(crate) struct OriginFn { + pub(crate) boxed_fn: Rc bool>, } -pub(crate) struct OriginFn { - pub(crate) boxed_fn: Box bool>, +impl Default for OriginFn { + /// Dummy default for use in tiny_vec. Do not use. + fn default() -> Self { + let boxed_fn: Rc _> = Rc::new(|_req_head| false); + Self { boxed_fn } + } } impl fmt::Debug for OriginFn { @@ -33,14 +32,28 @@ impl fmt::Debug for OriginFn { } } -#[derive(Debug)] +/// Try to parse header value as HTTP method. +fn header_value_try_into_method(hdr: &HeaderValue) -> Option { + hdr.to_str() + .ok() + .and_then(|meth| Method::try_from(meth).ok()) +} + +#[derive(Debug, Clone)] pub(crate) struct Inner { - pub(crate) methods: HashSet, - pub(crate) origins: AllOrSome>, - pub(crate) origins_fns: Vec, - pub(crate) origins_str: Option, - pub(crate) headers: AllOrSome>, - pub(crate) expose_headers: Option, + pub(crate) allowed_origins: AllOrSome>, + pub(crate) allowed_origins_fns: TinyVec<[OriginFn; 4]>, + + pub(crate) allowed_methods: HashSet, + pub(crate) allowed_methods_baked: Option, + + pub(crate) allowed_headers: AllOrSome>, + pub(crate) allowed_headers_baked: Option, + + /// `All` will echo back `Access-Control-Request-Header` list. + pub(crate) expose_headers: AllOrSome>, + pub(crate) expose_headers_baked: Option, + pub(crate) max_age: Option, pub(crate) preflight: bool, pub(crate) send_wildcard: bool, @@ -48,90 +61,94 @@ pub(crate) struct Inner { pub(crate) vary_header: bool, } +static EMPTY_ORIGIN_SET: Lazy> = Lazy::new(HashSet::new); + impl Inner { pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ORIGIN) { - if let Ok(origin) = hdr.to_str() { - return match self.origins { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => allowed_origins - .get(origin) - .map(|_| ()) - .or_else(|| { - if self.validate_origin_fns(req) { - Some(()) - } else { - None - } - }) - .ok_or(CorsError::OriginNotAllowed), - }; - } - Err(CorsError::BadOrigin) - } else { - match self.origins { - AllOrSome::All => Ok(()), - _ => Err(CorsError::MissingOrigin), + // return early if all origins are allowed or get ref to allowed origins set + #[allow(clippy::mutable_key_type)] + let allowed_origins = match &self.allowed_origins { + AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(()), + AllOrSome::Some(allowed_origins) => allowed_origins, + // only function origin validators are defined + _ => &EMPTY_ORIGIN_SET, + }; + + // get origin header and try to parse as string + match req.headers().get(header::ORIGIN) { + // origin header exists and is a string + Some(origin) => { + if allowed_origins.contains(origin) || self.validate_origin_fns(req) { + Ok(()) + } else { + Err(CorsError::OriginNotAllowed) + } } + + // origin header is missing + // note: with our implementation, the origin header is required for OPTIONS request or + // else this would be unreachable + None => Err(CorsError::MissingOrigin), } } + /// Accepts origin if _ANY_ functions return true. pub(crate) fn validate_origin_fns(&self, req: &RequestHead) -> bool { - self.origins_fns + self.allowed_origins_fns .iter() .any(|origin_fn| (origin_fn.boxed_fn)(req)) } + /// Only called if origin exists and always after it's validated. pub(crate) fn access_control_allow_origin( &self, req: &RequestHead, ) -> Option { - match self.origins { + let origin = req.headers().get(header::ORIGIN); + + match self.allowed_origins { AllOrSome::All => { if self.send_wildcard { Some(HeaderValue::from_static("*")) - } else if let Some(origin) = req.headers().get(header::ORIGIN) { - Some(origin.clone()) } else { - None + // see note below about why `.cloned()` is correct + origin.cloned() } } - AllOrSome::Some(ref origins) => { - if let Some(origin) = - req.headers() - .get(header::ORIGIN) - .filter(|o| match o.to_str() { - Ok(os) => origins.contains(os), - _ => false, - }) - { - Some(origin.clone()) - } else if self.validate_origin_fns(req) { - Some(req.headers().get(header::ORIGIN).unwrap().clone()) - } else { - Some(self.origins_str.as_ref().unwrap().clone()) - } + + AllOrSome::Some(_) => { + // since origin (if it exists) is known to be allowed if this method is called + // then cloning the option is all that is required to be used as an echoed back + // header value (or omitted if None) + origin.cloned() } } } + /// Use in preflight checks and therefore operates on header list in + /// `Access-Control-Request-Headers` not the actual header set. pub(crate) fn validate_allowed_method( &self, req: &RequestHead, ) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { - if let Ok(meth) = hdr.to_str() { - if let Ok(method) = meth.try_into() { - return self - .methods - .get(&method) - .map(|_| ()) - .ok_or(CorsError::MethodNotAllowed); - } - } - Err(CorsError::BadRequestMethod) - } else { - Err(CorsError::MissingRequestMethod) + // extract access control header and try to parse as method + let request_method = req + .headers() + .get(header::ACCESS_CONTROL_REQUEST_METHOD) + .map(header_value_try_into_method); + + match request_method { + // method valid and allowed + Some(Some(method)) if self.allowed_methods.contains(&method) => Ok(()), + + // method valid but not allowed + Some(Some(_)) => Err(CorsError::MethodNotAllowed), + + // method invalid + Some(_) => Err(CorsError::BadRequestMethod), + + // method missing + None => Err(CorsError::MissingRequestMethod), } } @@ -139,34 +156,54 @@ impl Inner { &self, req: &RequestHead, ) -> Result<(), CorsError> { - match self.headers { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_headers) => { - if let Some(hdr) = - req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) - { - if let Ok(headers) = hdr.to_str() { - #[allow(clippy::mutable_key_type)] // FIXME: revisit here - let mut validated_headers = HashSet::new(); - for hdr in headers.split(',') { - match hdr.trim().try_into() { - Ok(hdr) => validated_headers.insert(hdr), - Err(_) => return Err(CorsError::BadRequestHeaders), - }; - } - // `Access-Control-Request-Headers` must contain 1 or more `field-name` - if !validated_headers.is_empty() { - if !validated_headers.is_subset(allowed_headers) { - return Err(CorsError::HeadersNotAllowed); - } - return Ok(()); - } - } - Err(CorsError::BadRequestHeaders) - } else { - Ok(()) + // return early if all headers are allowed or get ref to allowed origins set + #[allow(clippy::mutable_key_type)] + let allowed_headers = match &self.allowed_headers { + AllOrSome::All => return Ok(()), + AllOrSome::Some(allowed_headers) => allowed_headers, + }; + + // extract access control header as string + // header format should be comma separated header names + let request_headers = req + .headers() + .get(header::ACCESS_CONTROL_REQUEST_HEADERS) + .map(|hdr| hdr.to_str()); + + match request_headers { + // header list is valid string + Some(Ok(headers)) => { + // the set is ephemeral we take care not to mutate the + // inserted keys so this lint exception is acceptable + #[allow(clippy::mutable_key_type)] + let mut request_headers = HashSet::with_capacity(8); + + // try to convert each header name in the comma-separated list + for hdr in headers.split(',') { + match hdr.trim().try_into() { + Ok(hdr) => request_headers.insert(hdr), + Err(_) => return Err(CorsError::BadRequestHeaders), + }; } + + // header list must contain 1 or more header name + if request_headers.is_empty() { + return Err(CorsError::BadRequestHeaders); + } + + // request header list must be a subset of allowed headers + if !request_headers.is_subset(allowed_headers) { + return Err(CorsError::HeadersNotAllowed); + } + + Ok(()) } + + // header list is not a string + Some(Err(_)) => Err(CorsError::BadRequestHeaders), + + // header list missing + None => Ok(()), } } } @@ -177,40 +214,43 @@ mod test { use actix_web::{ dev::Transform, - http::{header, Method, StatusCode}, + http::{header, HeaderValue, Method, StatusCode}, test::{self, TestRequest}, }; use crate::Cors; + fn val_as_str(val: &HeaderValue) -> &str { + val.to_str().unwrap() + } + #[actix_rt::test] - #[should_panic(expected = "OriginNotAllowed")] async fn test_validate_not_allowed_origin() { - let cors = Cors::new() + let cors = Cors::default() .allowed_origin("https://www.example.com") - .finish() .new_transform(test::ok_service()) .await .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) + let req = TestRequest::get() + .header(header::ORIGIN, "https://www.unknown.com") + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT") .to_srv_request(); - cors.inner.validate_origin(req.head()).unwrap(); - cors.inner.validate_allowed_method(req.head()).unwrap(); - cors.inner.validate_allowed_headers(req.head()).unwrap(); + assert!(cors.inner.validate_origin(req.head()).is_err()); + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); } #[actix_rt::test] async fn test_preflight() { - let mut cors = Cors::new() + let mut cors = Cors::default() + .allow_any_origin() .send_wildcard() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -244,24 +284,22 @@ mod test { let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"*"[..], + Some(&b"*"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); assert_eq!( - &b"3600"[..], + Some(&b"3600"[..]), resp.headers() .get(header::ACCESS_CONTROL_MAX_AGE) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); + let hdr = resp .headers() .get(header::ACCESS_CONTROL_ALLOW_HEADERS) - .unwrap() - .to_str() + .map(val_as_str) .unwrap(); assert!(hdr.contains("authorization")); assert!(hdr.contains("accept")); diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index dc205908a..9d4738287 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -1,15 +1,12 @@ //! Cross-Origin Resource Sharing (CORS) controls for Actix Web. //! -//! This middleware can be applied to both applications and resources. Once built, -//! [`CorsFactory`] can be used as a parameter for actix-web `App::wrap()`, +//! This middleware can be applied to both applications and resources. Once built, a +//! [`Cors`] builder can be used as an argument for Actix Web's `App::wrap()`, //! `Scope::wrap()`, or `Resource::wrap()` methods. //! //! This CORS middleware automatically handles `OPTIONS` preflight requests. //! //! # Example -//! -//! In this example a custom CORS middleware is registered for the "/index.html" endpoint. -//! //! ```rust,no_run //! use actix_cors::Cors; //! use actix_web::{get, http, web, App, HttpRequest, HttpResponse, HttpServer}; @@ -22,7 +19,7 @@ //! #[actix_web::main] //! async fn main() -> std::io::Result<()> { //! HttpServer::new(|| { -//! let cors = Cors::new() +//! let cors = Cors::default() //! .allowed_origin("https://www.rust-lang.org/") //! .allowed_origin_fn(|req| { //! req.headers @@ -34,8 +31,7 @@ //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600) -//! .finish(); +//! .max_age(3600); //! //! App::new() //! .wrap(cors) @@ -61,8 +57,8 @@ mod error; mod inner; mod middleware; -pub use all_or_some::AllOrSome; -pub use builder::{Cors, CorsFactory}; +use all_or_some::AllOrSome; +pub use builder::Cors; pub use error::CorsError; -use inner::{cors, Inner, OriginFn}; +use inner::{Inner, OriginFn}; pub use middleware::CorsMiddleware; diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index 86aa11c09..a071717a9 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -14,6 +14,7 @@ use actix_web::{ HttpResponse, }; use futures_util::future::{ok, Either, FutureExt as _, LocalBoxFuture, Ready}; +use log::debug; use crate::Inner; @@ -28,6 +29,93 @@ pub struct CorsMiddleware { pub(crate) inner: Rc, } +impl CorsMiddleware { + fn handle_preflight(inner: &Inner, req: ServiceRequest) -> ServiceResponse { + if let Err(err) = inner + .validate_origin(req.head()) + .and_then(|_| inner.validate_allowed_method(req.head())) + .and_then(|_| inner.validate_allowed_headers(req.head())) + { + return req.error_response(err); + } + + let mut res = HttpResponse::Ok(); + + if let Some(origin) = inner.access_control_allow_origin(req.head()) { + res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } + + if let Some(ref allowed_methods) = inner.allowed_methods_baked { + res.header( + header::ACCESS_CONTROL_ALLOW_METHODS, + allowed_methods.clone(), + ); + } + + if let Some(ref headers) = inner.allowed_headers_baked { + res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()); + } else if let Some(headers) = + req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) + { + // all headers allowed, return + res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()); + } + + if inner.supports_credentials { + res.header( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + + if let Some(max_age) = inner.max_age { + res.header(header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()); + } + + let res = res.finish(); + let res = res.into_body(); + req.into_response(res) + } + + fn augment_response( + inner: &Inner, + mut res: ServiceResponse, + ) -> ServiceResponse { + if let Some(origin) = inner.access_control_allow_origin(res.request().head()) { + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }; + + if let Some(ref expose) = inner.expose_headers_baked { + res.headers_mut() + .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone()); + } + + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + + if inner.vary_header { + let value = match res.headers_mut().get(header::VARY) { + Some(hdr) => { + let mut val: Vec = Vec::with_capacity(hdr.len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + val.try_into().unwrap() + } + None => HeaderValue::from_static("Origin"), + }; + + res.headers_mut().insert(header::VARY, value); + } + + res + } +} + type CorsMiddlewareServiceFuture = Either< Ready, Error>>, LocalBoxFuture<'static, Result, Error>>, @@ -49,123 +137,84 @@ where } fn call(&mut self, req: ServiceRequest) -> Self::Future { - if self.inner.preflight && Method::OPTIONS == *req.method() { - if let Err(e) = self - .inner - .validate_origin(req.head()) - .and_then(|_| self.inner.validate_allowed_method(req.head())) - .and_then(|_| self.inner.validate_allowed_headers(req.head())) - { - return Either::Left(ok(req.error_response(e))); - } - - // allowed headers - let headers = if let Some(headers) = self.inner.headers.as_ref() { - Some( - headers - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..] - .try_into() - .unwrap(), - ) - } else if let Some(hdr) = - req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) - { - Some(hdr.clone()) - } else { - None - }; - - let res = HttpResponse::Ok() - .if_some(self.inner.max_age.as_ref(), |max_age, resp| { - let _ = resp.header( - header::ACCESS_CONTROL_MAX_AGE, - format!("{}", max_age).as_str(), - ); - }) - .if_some(headers, |headers, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }) - .if_some( - self.inner.access_control_allow_origin(req.head()), - |origin, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }, - ) - .if_true(self.inner.supports_credentials, |resp| { - resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }) - .header( - header::ACCESS_CONTROL_ALLOW_METHODS, - &self - .inner - .methods - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ) - .finish() - .into_body(); - - Either::Left(ok(req.into_response(res))) + if self.inner.preflight && req.method() == Method::OPTIONS { + let inner = Rc::clone(&self.inner); + let res = Self::handle_preflight(&inner, req); + Either::Left(ok(res)) } else { - if req.headers().contains_key(header::ORIGIN) { + let origin = req.headers().get(header::ORIGIN).cloned(); + + if origin.is_some() { // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::Left(ok(req.error_response(e))); + if let Err(err) = self.inner.validate_origin(req.head()) { + debug!("origin validation failed; inner service is not called"); + return Either::Left(ok(req.error_response(err))); } } let inner = Rc::clone(&self.inner); - let has_origin = req.headers().contains_key(header::ORIGIN); let fut = self.service.call(req); - Either::Right( - async move { - let res = fut.await; + let res = async move { + let res = fut.await; - if has_origin { - let mut res = res?; - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }; - - if let Some(ref expose) = inner.expose_headers { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - expose.as_str().try_into().unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = - if let Some(hdr) = res.headers_mut().get(header::VARY) { - let mut val: Vec = - Vec::with_capacity(hdr.as_bytes().len() + 8); - val.extend(hdr.as_bytes()); - val.extend(b", Origin"); - val.try_into().unwrap() - } else { - HeaderValue::from_static("Origin") - }; - res.headers_mut().insert(header::VARY, value); - } - Ok(res) - } else { - res - } + if origin.is_some() { + let res = res?; + Ok(Self::augment_response(&inner, res)) + } else { + res } - .boxed_local(), - ) + } + .boxed_local(); + + Either::Right(res) } } } + +#[cfg(test)] +mod tests { + use actix_web::{ + dev::Transform, + test::{self, TestRequest}, + }; + + use super::*; + use crate::Cors; + + #[actix_rt::test] + async fn test_options_no_origin() { + // Tests case where allowed_origins is All but there are validate functions to run incase. + // In this case, origins are only allowed when the DNT header is sent. + + let mut cors = Cors::default() + .allow_any_origin() + .allowed_origin_fn(|req_head| req_head.headers().contains_key(header::DNT)) + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::get() + .header(header::ORIGIN, "http://example.com") + .to_srv_request(); + let res = cors.call(req).await.unwrap(); + assert_eq!( + None, + res.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + + let req = TestRequest::get() + .header(header::ORIGIN, "http://example.com") + .header(header::DNT, "1") + .to_srv_request(); + let res = cors.call(req).await.unwrap(); + assert_eq!( + Some(&b"http://example.com"[..]), + res.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + } +} diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs index dabb5c695..e239f68b2 100644 --- a/actix-cors/tests/tests.rs +++ b/actix-cors/tests/tests.rs @@ -10,12 +10,15 @@ use regex::bytes::Regex; use actix_cors::Cors; +fn val_as_str(val: &HeaderValue) -> &str { + val.to_str().unwrap() +} + #[actix_rt::test] #[should_panic] async fn test_wildcard_origin() { - Cors::new() + Cors::default() .allowed_origin("*") - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -23,7 +26,7 @@ async fn test_wildcard_origin() { #[actix_rt::test] async fn test_not_allowed_origin_fn() { - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(|req| { req.headers @@ -32,7 +35,6 @@ async fn test_not_allowed_origin_fn() { .filter(|b| b.ends_with(b".unknown.com")) .is_some() }) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -68,7 +70,7 @@ async fn test_not_allowed_origin_fn() { #[actix_rt::test] async fn test_allowed_origin_fn() { - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(|req| { req.headers @@ -77,7 +79,6 @@ async fn test_allowed_origin_fn() { .filter(|b| b.ends_with(b".unknown.com")) .is_some() }) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -92,8 +93,7 @@ async fn test_allowed_origin_fn() { "https://www.example.com", resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() + .map(val_as_str) .unwrap() ); @@ -114,7 +114,7 @@ async fn test_allowed_origin_fn() { #[actix_rt::test] async fn test_allowed_origin_fn_with_environment() { let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(move |req| { req.headers @@ -123,7 +123,6 @@ async fn test_allowed_origin_fn_with_environment() { .filter(|b| regex.is_match(b)) .is_some() }) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -138,8 +137,7 @@ async fn test_allowed_origin_fn_with_environment() { "https://www.example.com", resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() + .map(val_as_str) .unwrap() ); @@ -159,11 +157,10 @@ async fn test_allowed_origin_fn_with_environment() { #[actix_rt::test] async fn test_multiple_origins_preflight() { - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -175,11 +172,10 @@ async fn test_multiple_origins_preflight() { let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"https://example.com"[..], + Some(&b"https://example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); let req = TestRequest::with_header("Origin", "https://example.org") @@ -189,21 +185,19 @@ async fn test_multiple_origins_preflight() { let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"https://example.org"[..], + Some(&b"https://example.org"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); } #[actix_rt::test] async fn test_multiple_origins() { - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -214,11 +208,10 @@ async fn test_multiple_origins() { let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"https://example.com"[..], + Some(&b"https://example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); let req = TestRequest::with_header("Origin", "https://example.org") @@ -227,18 +220,18 @@ async fn test_multiple_origins() { let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"https://example.org"[..], + Some(&b"https://example.org"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); } #[actix_rt::test] async fn test_response() { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() + let mut cors = Cors::default() + .allow_any_origin() .send_wildcard() .disable_preflight() .max_age(3600) @@ -246,7 +239,6 @@ async fn test_response() { .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -254,18 +246,16 @@ async fn test_response() { let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"*"[..], + Some(&b"*"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() + Some(&b"Origin"[..]), + resp.headers().get(header::VARY).map(HeaderValue::as_bytes) ); #[allow(clippy::needless_collect)] @@ -273,20 +263,21 @@ async fn test_response() { let headers = resp .headers() .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) - .unwrap() - .to_str() + .map(val_as_str) .unwrap() .split(',') .map(|s| s.trim()) .collect::>(); + // TODO: use HashSet subset check for h in exposed_headers { assert!(headers.contains(&h.as_str())); } } let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() + let mut cors = Cors::default() + .allow_any_origin() .send_wildcard() .disable_preflight() .max_age(3600) @@ -294,28 +285,28 @@ async fn test_response() { .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) - .finish() .new_transform(fn_service(|req: ServiceRequest| { - ok(req.into_response( - HttpResponse::Ok().header(header::VARY, "Accept").finish(), - )) + ok(req.into_response({ + HttpResponse::Ok().header(header::VARY, "Accept").finish() + })) })) .await .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"Accept, Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() + Some(&b"Accept, Origin"[..]), + resp.headers().get(header::VARY).map(HeaderValue::as_bytes) ); - let mut cors = Cors::new() + let mut cors = Cors::default() .disable_vary_header() + .allowed_methods(vec!["POST"]) .allowed_origin("https://www.example.com") .allowed_origin("https://www.google.com") - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -325,22 +316,17 @@ async fn test_response() { .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .to_srv_request(); let resp = test::call_service(&mut cors, req).await; - let origins_str = resp .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap(); - - assert_eq!("https://www.example.com", origins_str); + .map(val_as_str); + assert_eq!(Some("https://www.example.com"), origins_str); } #[actix_rt::test] async fn test_validate_origin() { - let mut cors = Cors::new() + let mut cors = Cors::default() .allowed_origin("https://www.example.com") - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -355,9 +341,8 @@ async fn test_validate_origin() { #[actix_rt::test] async fn test_no_origin_response() { - let mut cors = Cors::new() + let mut cors = Cors::permissive() .disable_preflight() - .finish() .new_transform(test::ok_service()) .await .unwrap(); @@ -374,18 +359,16 @@ async fn test_no_origin_response() { .to_srv_request(); let resp = test::call_service(&mut cors, req).await; assert_eq!( - &b"https://www.example.com"[..], + Some(&b"https://www.example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + .map(HeaderValue::as_bytes) ); } #[actix_rt::test] async fn validate_origin_allows_all_origins() { - let mut cors = Cors::new() - .finish() + let mut cors = Cors::permissive() .new_transform(test::ok_service()) .await .unwrap();