diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 444540d7c..9c73b0efe 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -16,11 +16,11 @@ name = "actix_cors" path = "src/lib.rs" [dependencies] -actix-web = { version = "3.0.0", default_features = false } -actix-service = "1.0.6" +actix-web = { version = "3.0.0", default-features = false } derive_more = "0.99.2" futures-util = { version = "0.3.4", default-features = false } [dev-dependencies] +actix-service = "1.0.6" actix-rt = "1.1.1" regex = "1.3.9" diff --git a/actix-cors/src/all_or_some.rs b/actix-cors/src/all_or_some.rs new file mode 100644 index 000000000..065bf29e5 --- /dev/null +++ b/actix-cors/src/all_or_some.rs @@ -0,0 +1,46 @@ +/// An enum signifying that some of type `T` is allowed, or `All` (anything is allowed). +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AllOrSome { + /// Everything is allowed. Usually equivalent to the `*` value. + All, + + /// Only some of `T` is allowed + Some(T), +} + +/// Default as `AllOrSome::All`. +impl Default for AllOrSome { + fn default() -> Self { + AllOrSome::All + } +} + +impl AllOrSome { + /// Returns whether this is an `All` variant. + pub fn is_all(&self) -> bool { + matches!(self, AllOrSome::All) + } + + /// Returns whether this is a `Some` variant. + pub fn is_some(&self) -> bool { + !self.is_all() + } + + /// Returns &T. + pub fn as_ref(&self) -> Option<&T> { + match *self { + AllOrSome::All => None, + AllOrSome::Some(ref t) => Some(t), + } + } +} + +#[cfg(test)] +#[test] +fn tests() { + assert!(AllOrSome::<()>::All.is_all()); + assert!(!AllOrSome::<()>::All.is_some()); + + assert!(!AllOrSome::Some(()).is_all()); + assert!(AllOrSome::Some(()).is_some()); +} diff --git a/actix-cors/src/builder.rs b/actix-cors/src/builder.rs new file mode 100644 index 000000000..a72fa03ab --- /dev/null +++ b/actix-cors/src/builder.rs @@ -0,0 +1,508 @@ +use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc}; + +use actix_web::{ + dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform}, + error::{Error, Result}, + http::{self, header::HeaderName, Error as HttpError, Method, Uri}, +}; +use futures_util::future::{ok, Ready}; + +use crate::{cors, AllOrSome, CorsMiddleware, Inner, OriginFn}; + +/// Builder for CORS middleware. +/// +/// To construct a CORS middleware: +/// +/// 1. Call [`Cors::new()`] to start building. +/// 2. Use any of the builder methods to customize CORS behavior. +/// 3. Call [`Cors::finish()`] to build the middleware. +/// +/// # Example +/// +/// ```rust +/// use actix_cors::{Cors, CorsFactory}; +/// use actix_web::http::header; +/// +/// let cors = Cors::new() +/// .allowed_origin("https://www.rust-lang.org") +/// .allowed_methods(vec!["GET", "POST"]) +/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) +/// .allowed_header(header::CONTENT_TYPE) +/// .max_age(3600) +/// .finish(); +/// +/// // `cors` can now be used in `App::wrap`. +/// ``` +#[derive(Debug, Default)] +pub struct Cors { + cors: Option, + methods: bool, + error: Option, + expose_headers: HashSet, +} + +impl Cors { + /// Return a new builder. + pub fn new() -> Cors { + Cors { + cors: Some(Inner { + origins: AllOrSome::All, + origins_str: None, + origins_fns: Vec::new(), + methods: HashSet::new(), + headers: AllOrSome::All, + expose_headers: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }), + methods: false, + error: None, + expose_headers: HashSet::new(), + } + } + + /// Build a CORS middleware with default settings. + pub fn default() -> CorsFactory { + let inner = Inner { + origins: AllOrSome::default(), + origins_str: None, + origins_fns: Vec::new(), + methods: HashSet::from_iter( + vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ] + .into_iter(), + ), + headers: AllOrSome::All, + expose_headers: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }; + + CorsFactory { + inner: Rc::new(inner), + } + } + + /// Add an origin that is allowed to make requests. + /// + /// By default, requests from all origins are accepted by CORS logic. This method allows to + /// specify a finite set of origins to verify the value of the `Origin` request header. + /// + /// These are `origin-or-null` types in the [Fetch Standard]. + /// + /// When this list is set, the client's `Origin` request header will be checked in a + /// case-sensitive manner. + /// + /// When all origins are allowed and `send_wildcard` is set, `*` will be sent in the + /// `Access-Control-Allow-Origin` response header. If `send_wildcard` is not set, the client's + /// `Origin` request header will be echoed back in the `Access-Control-Allow-Origin` + /// response header. + /// + /// If the origin of the request doesn't match any allowed origins and at least one + /// `allowed_origin_fn` function is set, these functions will be used to determinate + /// allowed origins. + /// + /// Builder panics if supplied origin is not valid uri. + /// + /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header + pub fn allowed_origin(mut self, origin: &str) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match TryInto::::try_into(origin) { + Ok(_) => { + if cors.origins.is_all() { + cors.origins = AllOrSome::Some(HashSet::new()); + } + + if let AllOrSome::Some(ref mut origins) = cors.origins { + origins.insert(origin.to_owned()); + } + } + + Err(e) => { + self.error = Some(e.into()); + } + } + } + + self + } + + /// Determinate allowed origins by processing requests which didn't match any origins specified + /// in the `allowed_origin`. + /// + /// The function will receive a `RequestHead` of each request, which can be used to determine + /// whether it should be allowed or not. + /// + /// If the function returns `true`, the client's `Origin` request header will be echoed back + /// into the `Access-Control-Allow-Origin` response header. + pub fn allowed_origin_fn(mut self, f: F) -> Cors + where + F: (Fn(&RequestHead) -> bool) + 'static, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.origins_fns.push(OriginFn { + boxed_fn: Box::new(f), + }); + } + + self + } + + /// Set a list of methods which allowed origins can perform. + /// + /// These will be sent in the `Access-Control-Allow-Methods` response header as specified in + /// the [Fetch Standard CORS protocol]. + /// + /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` + /// + /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol + pub fn allowed_methods(mut self, methods: U) -> Cors + where + U: IntoIterator, + M: TryInto, + >::Error: Into, + { + self.methods = true; + if let Some(cors) = cors(&mut self.cors, &self.error) { + for m in methods { + match m.try_into() { + Ok(method) => { + cors.methods.insert(method); + } + + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + + self + } + + /// Add an allowed header. + /// + /// See `Cors::allowed_headers()` for details. + pub fn allowed_header(mut self, header: H) -> Cors + where + H: TryInto, + >::Error: Into, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match header.try_into() { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + + Err(e) => self.error = Some(e.into()), + } + } + + self + } + + /// Set a list of header field names which can be used when this resource is accessed by + /// allowed origins. + /// + /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers` + /// will be echoed back in the `Access-Control-Allow-Headers` header as specified in + /// the [Fetch Standard CORS protocol]. + /// + /// Defaults to `All`. + /// + /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol + pub fn allowed_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + H: TryInto, + >::Error: Into, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + for h in headers { + match h.try_into() { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + + self + } + + /// Set a list of headers which are safe to expose to the API of a CORS API specification. + /// This corresponds to the `Access-Control-Expose-Headers` response header as specified in + /// the [Fetch Standard CORS protocol]. + /// + /// This defaults to an empty set. + /// + /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol + pub fn expose_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + H: TryInto, + >::Error: Into, + { + for h in headers { + match h.try_into() { + Ok(method) => { + self.expose_headers.insert(method); + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + + self + } + + /// Set a maximum time for which this CORS request maybe cached. + /// This value is set as the `Access-Control-Max-Age` header as specified in + /// the [Fetch Standard CORS protocol]. + /// + /// This defaults to `None` (unset). + /// + /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol + pub fn max_age(mut self, max_age: usize) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.max_age = Some(max_age) + } + + self + } + + /// Set to use wildcard origins. + /// + /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard + /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s + /// `Origin` header. + /// + /// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and + /// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result + /// in an `CorsError::CredentialsWithWildcardOrigin` error during actix launch or runtime. + /// + /// Defaults to `false`. + pub fn send_wildcard(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.send_wildcard = true + } + + self + } + + /// Allows users to make authenticated requests + /// + /// If true, injects the `Access-Control-Allow-Credentials` header in responses. This allows + /// cookies and credentials to be submitted across domains as specified in + /// the [Fetch Standard CORS protocol]. + /// + /// This option cannot be used in conjunction with an `allowed_origin` set to `All` and + /// `send_wildcards` set to `true`. + /// + /// Defaults to `false`. + /// + /// Builder panics during `finish` if credentials are allowed, but the Origin is set to `*`. + /// This is not allowed by the CORS protocol. + /// + /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol + pub fn supports_credentials(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.supports_credentials = true + } + + self + } + + /// Disable `Vary` header support. + /// + /// When enabled the header `Vary: Origin` will be returned as per the Fetch Standard + /// implementation guidelines. + /// + /// Setting this header when the `Access-Control-Allow-Origin` is dynamically generated + /// (eg. when there is more than one allowed origin, and an Origin other than '*' is returned) + /// informs CDNs and other caches that the CORS headers are dynamic, and cannot be cached. + /// + /// By default, `Vary` header support is enabled. + pub fn disable_vary_header(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.vary_header = false + } + + self + } + + /// Disable support for preflight requests. + /// + /// When enabled CORS middleware automatically handles `OPTIONS` requests. + /// This is useful for application level middleware. + /// + /// By default *preflight* support is enabled. + pub fn disable_preflight(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.preflight = false + } + + self + } + + /// Construct CORS middleware. + pub fn finish(self) -> CorsFactory { + let mut this = if !self.methods { + self.allowed_methods(vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ]) + } else { + self + }; + + if let Some(e) = this.error.take() { + panic!("{}", e); + } + + let mut cors = this.cors.take().expect("cannot reuse CorsBuilder"); + + if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { + panic!("Credentials are allowed, but the Origin is set to \"*\""); + } + + if let AllOrSome::Some(ref origins) = cors.origins { + let s = origins + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v)); + cors.origins_str = Some(s[2..].try_into().unwrap()); + } + + if !this.expose_headers.is_empty() { + cors.expose_headers = Some( + this.expose_headers + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] + .to_owned(), + ); + } + + CorsFactory { + inner: Rc::new(cors), + } + } +} + +/// Middleware for Cross-Origin Resource Sharing support. +/// +/// This struct contains the settings for CORS requests to be validated and for responses to +/// be generated. +#[derive(Debug)] +pub struct CorsFactory { + inner: Rc, +} + +impl Transform for CorsFactory +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = CorsMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CorsMiddleware { + service, + inner: Rc::clone(&self.inner), + }) + } +} + +#[cfg(test)] +mod test { + use std::convert::{Infallible, TryInto}; + + use actix_web::{ + dev::Transform, + http::{HeaderName, StatusCode}, + test::{self, TestRequest}, + }; + + use super::*; + + #[test] + #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] + fn cors_validates_illegal_allow_credentials() { + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); + } + + #[actix_rt::test] + async fn default() { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn allowed_header_try_from() { + let _cors = Cors::new().allowed_header("Content-Type"); + } + + #[actix_rt::test] + async fn allowed_header_try_into() { + struct ContentType; + + impl TryInto for ContentType { + type Error = Infallible; + + fn try_into(self) -> Result { + Ok(HeaderName::from_static("content-type")) + } + } + + let _cors = Cors::new().allowed_header(ContentType); + } +} diff --git a/actix-cors/src/error.rs b/actix-cors/src/error.rs new file mode 100644 index 000000000..a986e08c7 --- /dev/null +++ b/actix-cors/src/error.rs @@ -0,0 +1,55 @@ +use actix_web::{http::StatusCode, HttpResponse, ResponseError}; + +use derive_more::{Display, Error}; + +/// Errors that can occur when processing CORS guarded requests. +#[derive(Debug, Clone, Display, Error)] +pub enum CorsError { + /// Request header `Origin` is required but was not provided. + #[display(fmt = "Request header `Origin` is required but was not provided.")] + MissingOrigin, + + /// Request header `Origin` could not be parsed correctly. + #[display(fmt = "Request header `Origin` could not be parsed correctly.")] + BadOrigin, + + /// Request header `Access-Control-Request-Method` is required but is missing. + #[display( + fmt = "Request header `Access-Control-Request-Method` is required but is missing." + )] + MissingRequestMethod, + + /// Request header `Access-Control-Request-Method` has an invalid value. + #[display( + fmt = "Request header `Access-Control-Request-Method` has an invalid value." + )] + BadRequestMethod, + + /// Request header `Access-Control-Request-Headers` has an invalid value. + #[display( + fmt = "Request header `Access-Control-Request-Headers` has an invalid value." + )] + BadRequestHeaders, + + /// Origin is not allowed to make this request. + #[display(fmt = "Origin is not allowed to make this request.")] + OriginNotAllowed, + + /// Request method is not allowed. + #[display(fmt = "Requested method is not allowed.")] + MethodNotAllowed, + + /// One or more request headers are not allowed. + #[display(fmt = "One or more request headers are not allowed.")] + HeadersNotAllowed, +} + +impl ResponseError for CorsError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::with_body(StatusCode::BAD_REQUEST, self.to_string().into()) + } +} diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs new file mode 100644 index 000000000..1b311ed22 --- /dev/null +++ b/actix-cors/src/inner.rs @@ -0,0 +1,294 @@ +use std::{collections::HashSet, convert::TryInto, fmt}; + +use actix_web::{ + dev::RequestHead, + error::Result, + http::{ + self, + header::{self, HeaderName, HeaderValue}, + Method, + }, +}; + +use crate::{AllOrSome, CorsError}; + +pub(crate) fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + + parts.as_mut() +} + +pub(crate) struct OriginFn { + pub(crate) boxed_fn: Box bool>, +} + +impl fmt::Debug for OriginFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("origin_fn") + } +} + +#[derive(Debug)] +pub(crate) struct Inner { + pub(crate) methods: HashSet, + pub(crate) origins: AllOrSome>, + pub(crate) origins_fns: Vec, + pub(crate) origins_str: Option, + pub(crate) headers: AllOrSome>, + pub(crate) expose_headers: Option, + pub(crate) max_age: Option, + pub(crate) preflight: bool, + pub(crate) send_wildcard: bool, + pub(crate) supports_credentials: bool, + pub(crate) vary_header: bool, +} + +impl Inner { + pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(header::ORIGIN) { + if let Ok(origin) = hdr.to_str() { + return match self.origins { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => allowed_origins + .get(origin) + .map(|_| ()) + .or_else(|| { + if self.validate_origin_fns(req) { + Some(()) + } else { + None + } + }) + .ok_or(CorsError::OriginNotAllowed), + }; + } + Err(CorsError::BadOrigin) + } else { + match self.origins { + AllOrSome::All => Ok(()), + _ => Err(CorsError::MissingOrigin), + } + } + } + + pub(crate) fn validate_origin_fns(&self, req: &RequestHead) -> bool { + self.origins_fns + .iter() + .any(|origin_fn| (origin_fn.boxed_fn)(req)) + } + + pub(crate) fn access_control_allow_origin( + &self, + req: &RequestHead, + ) -> Option { + match self.origins { + AllOrSome::All => { + if self.send_wildcard { + Some(HeaderValue::from_static("*")) + } else if let Some(origin) = req.headers().get(header::ORIGIN) { + Some(origin.clone()) + } else { + None + } + } + AllOrSome::Some(ref origins) => { + if let Some(origin) = + req.headers() + .get(header::ORIGIN) + .filter(|o| match o.to_str() { + Ok(os) => origins.contains(os), + _ => false, + }) + { + Some(origin.clone()) + } else if self.validate_origin_fns(req) { + Some(req.headers().get(header::ORIGIN).unwrap().clone()) + } else { + Some(self.origins_str.as_ref().unwrap().clone()) + } + } + } + } + + pub(crate) fn validate_allowed_method( + &self, + req: &RequestHead, + ) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { + if let Ok(meth) = hdr.to_str() { + if let Ok(method) = meth.try_into() { + return self + .methods + .get(&method) + .map(|_| ()) + .ok_or(CorsError::MethodNotAllowed); + } + } + Err(CorsError::BadRequestMethod) + } else { + Err(CorsError::MissingRequestMethod) + } + } + + pub(crate) fn validate_allowed_headers( + &self, + req: &RequestHead, + ) -> Result<(), CorsError> { + match self.headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if let Some(hdr) = + req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) + { + if let Ok(headers) = hdr.to_str() { + #[allow(clippy::mutable_key_type)] // FIXME: revisit here + let mut validated_headers = HashSet::new(); + for hdr in headers.split(',') { + match hdr.trim().try_into() { + Ok(hdr) => validated_headers.insert(hdr), + Err(_) => return Err(CorsError::BadRequestHeaders), + }; + } + // `Access-Control-Request-Headers` must contain 1 or more `field-name` + if !validated_headers.is_empty() { + if !validated_headers.is_subset(allowed_headers) { + return Err(CorsError::HeadersNotAllowed); + } + return Ok(()); + } + } + Err(CorsError::BadRequestHeaders) + } else { + Ok(()) + } + } + } + } +} + +#[cfg(test)] +mod test { + use std::rc::Rc; + + use actix_web::{ + dev::Transform, + http::{header, Method, StatusCode}, + test::{self, TestRequest}, + }; + + use crate::Cors; + + #[actix_rt::test] + #[should_panic(expected = "OriginNotAllowed")] + async fn test_validate_not_allowed_origin() { + let cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + + cors.inner.validate_origin(req.head()).unwrap(); + cors.inner.validate_allowed_method(req.head()).unwrap(); + cors.inner.validate_allowed_headers(req.head()).unwrap(); + } + + #[actix_rt::test] + async fn test_preflight() { + let mut cors = Cors::new() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"3600"[..], + resp.headers() + .get(header::ACCESS_CONTROL_MAX_AGE) + .unwrap() + .as_bytes() + ); + let hdr = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_HEADERS) + .unwrap() + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) + .unwrap() + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); + + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index 70308a775..dc205908a 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -1,29 +1,28 @@ //! Cross-Origin Resource Sharing (CORS) controls for Actix Web. //! -//! This middleware can be applied to both applications and resources. -//! Once built, [`CorsFactory`](struct.CorsFactory.html) can be used as a -//! parameter for actix-web `App::wrap()`, `Resource::wrap()` or -//! `Scope::wrap()` methods. +//! This middleware can be applied to both applications and resources. Once built, +//! [`CorsFactory`] can be used as a parameter for actix-web `App::wrap()`, +//! `Scope::wrap()`, or `Resource::wrap()` methods. //! //! This CORS middleware automatically handles `OPTIONS` preflight requests. //! //! # Example //! -//! In this example a custom CORS middleware is registered for the -//! "/index.html" endpoint. +//! In this example a custom CORS middleware is registered for the "/index.html" endpoint. //! -//! ```rust +//! ```rust,no_run //! use actix_cors::Cors; -//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; +//! use actix_web::{get, http, web, App, HttpRequest, HttpResponse, HttpServer}; //! +//! #[get("/index.html")] //! async fn index(req: HttpRequest) -> &'static str { -//! "Hello world" +//! "

Hello World!

" //! } //! -//! fn main() -> std::io::Result<()> { -//! HttpServer::new(|| App::new() -//! .wrap( -//! Cors::new() // <- Construct CORS middleware builder +//! #[actix_web::main] +//! async fn main() -> std::io::Result<()> { +//! HttpServer::new(|| { +//! let cors = Cors::new() //! .allowed_origin("https://www.rust-lang.org/") //! .allowed_origin_fn(|req| { //! req.headers @@ -36,1398 +35,34 @@ //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) //! .max_age(3600) -//! .finish()) -//! .service( -//! web::resource("/index.html") -//! .route(web::get().to(index)) -//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) -//! )) -//! .bind("127.0.0.1:8080")?; +//! .finish(); +//! +//! App::new() +//! .wrap(cors) +//! .service(index) +//! }) +//! .bind(("127.0.0.1", 8080))? +//! .run() +//! .await; //! //! Ok(()) //! } //! ``` -#![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] -#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)] - -use std::collections::HashSet; -use std::convert::TryInto; -use std::fmt; -use std::iter::FromIterator; -use std::rc::Rc; -use std::task::{Context, Poll}; - -use actix_service::{Service, Transform}; -use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; -use actix_web::error::{Error, ResponseError, Result}; -use actix_web::http::header::{self, HeaderName, HeaderValue}; -use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri}; -use actix_web::HttpResponse; -use derive_more::Display; -use futures_util::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; - -/// A set of errors that can occur as a result of processing CORS. -#[derive(Debug, Display)] -pub enum CorsError { - /// The HTTP request header `Origin` is required but was not provided - #[display( - fmt = "The HTTP request header `Origin` is required but was not provided" - )] - MissingOrigin, - /// The HTTP request header `Origin` could not be parsed correctly. - #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")] - BadOrigin, - /// The request header `Access-Control-Request-Method` is required but is - /// missing - #[display( - fmt = "The request header `Access-Control-Request-Method` is required but is missing" - )] - MissingRequestMethod, - /// The request header `Access-Control-Request-Method` has an invalid value - #[display( - fmt = "The request header `Access-Control-Request-Method` has an invalid value" - )] - BadRequestMethod, - /// The request header `Access-Control-Request-Headers` has an invalid - /// value - #[display( - fmt = "The request header `Access-Control-Request-Headers` has an invalid value" - )] - BadRequestHeaders, - /// Origin is not allowed to make this request - #[display(fmt = "Origin is not allowed to make this request")] - OriginNotAllowed, - /// Requested method is not allowed - #[display(fmt = "Requested method is not allowed")] - MethodNotAllowed, - /// One or more headers requested are not allowed - #[display(fmt = "One or more headers requested are not allowed")] - HeadersNotAllowed, -} - -impl ResponseError for CorsError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } - - fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) - } -} - -/// An enum signifying that some of type T is allowed, or `All` (everything is -/// allowed). -/// -/// `Default` is implemented for this enum and is `All`. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AllOrSome { - /// Everything is allowed. Usually equivalent to the "*" value. - All, - /// Only some of `T` is allowed - Some(T), -} - -impl Default for AllOrSome { - fn default() -> Self { - AllOrSome::All - } -} - -impl AllOrSome { - /// Returns whether this is an `All` variant - pub fn is_all(&self) -> bool { - match *self { - AllOrSome::All => true, - AllOrSome::Some(_) => false, - } - } - - /// Returns whether this is a `Some` variant - pub fn is_some(&self) -> bool { - !self.is_all() - } - - /// Returns &T - pub fn as_ref(&self) -> Option<&T> { - match *self { - AllOrSome::All => None, - AllOrSome::Some(ref t) => Some(t), - } - } -} - -/// Builder for `CorsFactory` middleware. -/// -/// To construct a [`CorsFactory`](struct.CorsFactory.html): -/// -/// 1. Call [`Cors::new()`](struct.Cors.html#method.new) to start building. -/// 2. Use any of the builder methods to customize CORS behavior. -/// 3. Call [`finish()`](struct.Cors.html#method.finish) to retrieve the -/// middleware. -/// -/// # Example -/// -/// ```rust -/// use actix_cors::Cors; -/// use actix_web::http::header; -/// -/// let cors = Cors::new() -/// .allowed_origin("https://www.rust-lang.org") -/// .allowed_methods(vec!["GET", "POST"]) -/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) -/// .allowed_header(header::CONTENT_TYPE) -/// .max_age(3600); -/// ``` -#[derive(Debug, Default)] -pub struct Cors { - cors: Option, - methods: bool, - error: Option, - expose_hdrs: HashSet, -} - -impl Cors { - /// Return a new builder. - pub fn new() -> Cors { - Cors { - cors: Some(Inner { - origins: AllOrSome::All, - origins_str: None, - origins_fns: Vec::new(), - methods: HashSet::new(), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }), - methods: false, - error: None, - expose_hdrs: HashSet::new(), - } - } - - /// Build a CORS middleware with default settings. - pub fn default() -> CorsFactory { - let inner = Inner { - origins: AllOrSome::default(), - origins_str: None, - origins_fns: Vec::new(), - methods: HashSet::from_iter( - vec![ - Method::GET, - Method::HEAD, - Method::POST, - Method::OPTIONS, - Method::PUT, - Method::PATCH, - Method::DELETE, - ] - .into_iter(), - ), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }; - CorsFactory { - inner: Rc::new(inner), - } - } - - /// Add an origin that is allowed to make requests. - /// - /// By default, requests from all origins are accepted by CORS logic. - /// This method allows to specify a finite set of origins to verify the - /// value of the `Origin` request header. - /// - /// This is the `list of origins` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// When this list is set, the client's `Origin` request header will be - /// checked in a case-sensitive manner. - /// - /// When all origins are allowed and `send_wildcard` is set, "*" will be - /// sent in the `Access-Control-Allow-Origin` response header. - /// If `send_wildcard` is not set, the client's `Origin` request header - /// will be echoed back in the `Access-Control-Allow-Origin` response header. - /// - /// If the origin of the request doesn't match any allowed origins and at least - /// one `allowed_origin_fn` function is set, these functions will be used - /// to determinate allowed origins. - /// - /// Builder panics if supplied origin is not valid uri. - pub fn allowed_origin(mut self, origin: &str) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - match TryInto::::try_into(origin) { - Ok(_) => { - if cors.origins.is_all() { - cors.origins = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut origins) = cors.origins { - origins.insert(origin.to_owned()); - } - } - Err(e) => { - self.error = Some(e.into()); - } - } - } - self - } - - /// Determinate allowed origins by processing requests which didn't match any origins - /// specified in the `allowed_origin`. - /// - /// The function will receive a `RequestHead` of each request, which can be used - /// to determine whether it should be allowed or not. - /// - /// If the function returns `true`, the client's `Origin` request header will be echoed - /// back into the `Access-Control-Allow-Origin` response header. - pub fn allowed_origin_fn(mut self, f: F) -> Cors - where - F: (Fn(&RequestHead) -> bool) + 'static, - { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.origins_fns.push(OriginFn { - boxed_fn: Box::new(f), - }); - } - self - } - - /// Set a list of methods which allowed origins can perform. - /// - /// This is the `list of methods` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` - pub fn allowed_methods(mut self, methods: U) -> Cors - where - U: IntoIterator, - M: TryInto, - >::Error: Into, - { - self.methods = true; - if let Some(cors) = cors(&mut self.cors, &self.error) { - for m in methods { - match m.try_into() { - Ok(method) => { - cors.methods.insert(method); - } - Err(e) => { - self.error = Some(e.into()); - break; - } - } - } - } - self - } - - /// Set an allowed header. - pub fn allowed_header(mut self, header: H) -> Cors - where - H: TryInto, - >::Error: Into, - { - if let Some(cors) = cors(&mut self.cors, &self.error) { - match header.try_into() { - Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut headers) = cors.headers { - headers.insert(method); - } - } - Err(e) => self.error = Some(e.into()), - } - } - self - } - - /// Set a list of header field names which can be used when - /// this resource is accessed by allowed origins. - /// - /// If `All` is set, whatever is requested by the client in - /// `Access-Control-Request-Headers` will be echoed back in the - /// `Access-Control-Allow-Headers` header. - /// - /// This is the `list of headers` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// Defaults to `All`. - pub fn allowed_headers(mut self, headers: U) -> Cors - where - U: IntoIterator, - H: TryInto, - >::Error: Into, - { - if let Some(cors) = cors(&mut self.cors, &self.error) { - for h in headers { - match h.try_into() { - Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut headers) = cors.headers { - headers.insert(method); - } - } - Err(e) => { - self.error = Some(e.into()); - break; - } - } - } - } - self - } - - /// Set a list of headers which are safe to expose to the API of a CORS API - /// specification. This corresponds to the - /// `Access-Control-Expose-Headers` response header. - /// - /// This is the `list of exposed headers` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// This defaults to an empty set. - pub fn expose_headers(mut self, headers: U) -> Cors - where - U: IntoIterator, - H: TryInto, - >::Error: Into, - { - for h in headers { - match h.try_into() { - Ok(method) => { - self.expose_hdrs.insert(method); - } - Err(e) => { - self.error = Some(e.into()); - break; - } - } - } - self - } - - /// Set a maximum time for which this CORS request maybe cached. - /// This value is set as the `Access-Control-Max-Age` header. - /// - /// This defaults to `None` (unset). - pub fn max_age(mut self, max_age: usize) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.max_age = Some(max_age) - } - self - } - - /// Set a wildcard origins - /// - /// If send wildcard is set and the `allowed_origins` parameter is `All`, a - /// wildcard `Access-Control-Allow-Origin` response header is sent, - /// rather than the request’s `Origin` header. - /// - /// This is the `supports credentials flag` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// This **CANNOT** be used in conjunction with `allowed_origins` set to - /// `All` and `allow_credentials` set to `true`. Depending on the mode - /// of usage, this will either result in an `Error:: - /// CredentialsWithWildcardOrigin` error during actix launch or runtime. - /// - /// Defaults to `false`. - pub fn send_wildcard(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.send_wildcard = true - } - self - } - - /// Allows users to make authenticated requests - /// - /// If true, injects the `Access-Control-Allow-Credentials` header in - /// responses. This allows cookies and credentials to be submitted - /// across domains. - /// - /// This option cannot be used in conjunction with an `allowed_origin` set - /// to `All` and `send_wildcards` set to `true`. - /// - /// Defaults to `false`. - /// - /// Builder panics if credentials are allowed, but the Origin is set to "*". - /// This is not allowed by W3C - pub fn supports_credentials(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.supports_credentials = true - } - self - } - - /// Disable `Vary` header support. - /// - /// When enabled the header `Vary: Origin` will be returned as per the W3 - /// implementation guidelines. - /// - /// Setting this header when the `Access-Control-Allow-Origin` is - /// dynamically generated (e.g. when there is more than one allowed - /// origin, and an Origin than '*' is returned) informs CDNs and other - /// caches that the CORS headers are dynamic, and cannot be cached. - /// - /// By default `vary` header support is enabled. - pub fn disable_vary_header(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.vary_header = false - } - self - } - - /// Disable support for preflight requests. - /// - /// When enabled CORS middleware automatically handles `OPTIONS` requests. - /// This is useful for application level middleware. - /// - /// By default *preflight* support is enabled. - pub fn disable_preflight(mut self) -> Cors { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.preflight = false - } - self - } - - /// Construct CORS middleware. - pub fn finish(self) -> CorsFactory { - let mut slf = if !self.methods { - self.allowed_methods(vec![ - Method::GET, - Method::HEAD, - Method::POST, - Method::OPTIONS, - Method::PUT, - Method::PATCH, - Method::DELETE, - ]) - } else { - self - }; - - if let Some(e) = slf.error.take() { - panic!("{}", e); - } - - let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder"); - - if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { - panic!("Credentials are allowed, but the Origin is set to \"*\""); - } - - if let AllOrSome::Some(ref origins) = cors.origins { - let s = origins - .iter() - .fold(String::new(), |s, v| format!("{}, {}", s, v)); - cors.origins_str = Some(s[2..].try_into().unwrap()); - } - - if !slf.expose_hdrs.is_empty() { - cors.expose_hdrs = Some( - slf.expose_hdrs - .iter() - .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] - .to_owned(), - ); - } - - CorsFactory { - inner: Rc::new(cors), - } - } -} - -fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; - } - parts.as_mut() -} - -/// Middleware for Cross-Origin Resource Sharing support. -/// -/// This struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -#[derive(Debug)] -pub struct CorsFactory { - inner: Rc, -} - -impl Transform for CorsFactory -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = Error; - type InitError = (); - type Transform = CorsMiddleware; - type Future = Ready>; - - fn new_transform(&self, service: S) -> Self::Future { - ok(CorsMiddleware { - service, - inner: self.inner.clone(), - }) - } -} - -/// Service wrapper for Cross-Origin Resource Sharing support. -/// -/// This struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -#[derive(Debug, Clone)] -pub struct CorsMiddleware { - service: S, - inner: Rc, -} - -struct OriginFn { - boxed_fn: Box bool>, -} - -impl fmt::Debug for OriginFn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "origin_fn") - } -} - -#[derive(Debug)] -struct Inner { - methods: HashSet, - origins: AllOrSome>, - origins_fns: Vec, - origins_str: Option, - headers: AllOrSome>, - expose_hdrs: Option, - max_age: Option, - preflight: bool, - send_wildcard: bool, - supports_credentials: bool, - vary_header: bool, -} - -impl Inner { - fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(&header::ORIGIN) { - if let Ok(origin) = hdr.to_str() { - return match self.origins { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => allowed_origins - .get(origin) - .map(|_| ()) - .or_else(|| { - if self.validate_origin_fns(req) { - Some(()) - } else { - None - } - }) - .ok_or(CorsError::OriginNotAllowed), - }; - } - Err(CorsError::BadOrigin) - } else { - match self.origins { - AllOrSome::All => Ok(()), - _ => Err(CorsError::MissingOrigin), - } - } - } - - fn validate_origin_fns(&self, req: &RequestHead) -> bool { - self.origins_fns - .iter() - .any(|origin_fn| (origin_fn.boxed_fn)(req)) - } - - fn access_control_allow_origin(&self, req: &RequestHead) -> Option { - match self.origins { - AllOrSome::All => { - if self.send_wildcard { - Some(HeaderValue::from_static("*")) - } else if let Some(origin) = req.headers().get(&header::ORIGIN) { - Some(origin.clone()) - } else { - None - } - } - AllOrSome::Some(ref origins) => { - if let Some(origin) = - req.headers() - .get(&header::ORIGIN) - .filter(|o| match o.to_str() { - Ok(os) => origins.contains(os), - _ => false, - }) - { - Some(origin.clone()) - } else if self.validate_origin_fns(req) { - Some(req.headers().get(&header::ORIGIN).unwrap().clone()) - } else { - Some(self.origins_str.as_ref().unwrap().clone()) - } - } - } - } - - fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) { - if let Ok(meth) = hdr.to_str() { - if let Ok(method) = meth.try_into() { - return self - .methods - .get(&method) - .map(|_| ()) - .ok_or(CorsError::MethodNotAllowed); - } - } - Err(CorsError::BadRequestMethod) - } else { - Err(CorsError::MissingRequestMethod) - } - } - - fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> { - match self.headers { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_headers) => { - if let Some(hdr) = - req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) - { - if let Ok(headers) = hdr.to_str() { - #[allow(clippy::mutable_key_type)] // FIXME: revisit here - let mut hdrs = HashSet::new(); - for hdr in headers.split(',') { - match hdr.trim().try_into() { - Ok(hdr) => hdrs.insert(hdr), - Err(_) => return Err(CorsError::BadRequestHeaders), - }; - } - // `Access-Control-Request-Headers` must contain 1 or more - // `field-name`. - if !hdrs.is_empty() { - if !hdrs.is_subset(allowed_headers) { - return Err(CorsError::HeadersNotAllowed); - } - return Ok(()); - } - } - Err(CorsError::BadRequestHeaders) - } else { - Ok(()) - } - } - } - } -} - -impl Service for CorsMiddleware -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = Error; - type Future = Either< - Ready>, - LocalBoxFuture<'static, Result>, - >; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } - - fn call(&mut self, req: ServiceRequest) -> Self::Future { - if self.inner.preflight && Method::OPTIONS == *req.method() { - if let Err(e) = self - .inner - .validate_origin(req.head()) - .and_then(|_| self.inner.validate_allowed_method(req.head())) - .and_then(|_| self.inner.validate_allowed_headers(req.head())) - { - return Either::Left(ok(req.error_response(e))); - } - - // allowed headers - let headers = if let Some(headers) = self.inner.headers.as_ref() { - Some( - headers - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..] - .try_into() - .unwrap(), - ) - } else if let Some(hdr) = - req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) - { - Some(hdr.clone()) - } else { - None - }; - - let res = HttpResponse::Ok() - .if_some(self.inner.max_age.as_ref(), |max_age, resp| { - let _ = resp.header( - header::ACCESS_CONTROL_MAX_AGE, - format!("{}", max_age).as_str(), - ); - }) - .if_some(headers, |headers, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }) - .if_some( - self.inner.access_control_allow_origin(req.head()), - |origin, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }, - ) - .if_true(self.inner.supports_credentials, |resp| { - resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }) - .header( - header::ACCESS_CONTROL_ALLOW_METHODS, - &self - .inner - .methods - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ) - .finish() - .into_body(); - - Either::Left(ok(req.into_response(res))) - } else { - if req.headers().contains_key(&header::ORIGIN) { - // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::Left(ok(req.error_response(e))); - } - } - - let inner = self.inner.clone(); - let has_origin = req.headers().contains_key(&header::ORIGIN); - let fut = self.service.call(req); - - Either::Right( - async move { - let res = fut.await; - - if has_origin { - let mut res = res?; - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }; - - if let Some(ref expose) = inner.expose_hdrs { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - expose.as_str().try_into().unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = if let Some(hdr) = - res.headers_mut().get(&header::VARY) - { - let mut val: Vec = - Vec::with_capacity(hdr.as_bytes().len() + 8); - val.extend(hdr.as_bytes()); - val.extend(b", Origin"); - val.try_into().unwrap() - } else { - HeaderValue::from_static("Origin") - }; - res.headers_mut().insert(header::VARY, value); - } - Ok(res) - } else { - res - } - } - .boxed_local(), - ) - } - } -} - -#[cfg(test)] -mod tests { - use actix_service::{fn_service, Transform}; - use actix_web::test::{self, TestRequest}; - use std::convert::Infallible; - - use super::*; - use regex::bytes::Regex; - - #[actix_rt::test] - async fn allowed_header_tryfrom() { - let _cors = Cors::new().allowed_header("Content-Type"); - } - - #[actix_rt::test] - async fn allowed_header_tryinto() { - struct ContentType; - - impl TryInto for ContentType { - type Error = Infallible; - - fn try_into(self) -> Result { - Ok(HeaderName::from_static("content-type")) - } - } - - let _cors = Cors::new().allowed_header(ContentType); - } - - #[actix_rt::test] - #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] - async fn cors_validates_illegal_allow_credentials() { - let _cors = Cors::new().supports_credentials().send_wildcard().finish(); - } - - #[actix_rt::test] - async fn validate_origin_allows_all_origins() { - let mut cors = Cors::new() - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::OK); - } - - #[actix_rt::test] - async fn default() { - let mut cors = Cors::default() - .new_transform(test::ok_service()) - .await - .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::OK); - } - - #[actix_rt::test] - async fn test_preflight() { - let mut cors = Cors::new() - .send_wildcard() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) - .allowed_header(header::CONTENT_TYPE) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") - .to_srv_request(); - - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); - let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") - .method(Method::OPTIONS) - .to_srv_request(); - - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"*"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - assert_eq!( - &b"3600"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_MAX_AGE) - .unwrap() - .as_bytes() - ); - let hdr = resp - .headers() - .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) - .unwrap() - .to_str() - .unwrap(); - assert!(hdr.contains("authorization")); - assert!(hdr.contains("accept")); - assert!(hdr.contains("content-type")); - - let methods = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_METHODS) - .unwrap() - .to_str() - .unwrap(); - assert!(methods.contains("POST")); - assert!(methods.contains("GET")); - assert!(methods.contains("OPTIONS")); - - Rc::get_mut(&mut cors.inner).unwrap().preflight = false; - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::OK); - } - - // #[actix_rt::test] - // #[should_panic(expected = "MissingOrigin")] - // async fn test_validate_missing_origin() { - // let cors = Cors::build() - // .allowed_origin("https://www.example.com") - // .finish(); - // let mut req = HttpRequest::default(); - // cors.start(&req).unwrap(); - // } - - #[actix_rt::test] - #[should_panic(expected = "OriginNotAllowed")] - async fn test_validate_not_allowed_origin() { - let cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .to_srv_request(); - cors.inner.validate_origin(req.head()).unwrap(); - cors.inner.validate_allowed_method(req.head()).unwrap(); - cors.inner.validate_allowed_headers(req.head()).unwrap(); - } - - #[actix_rt::test] - async fn test_validate_origin() { - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!(resp.status(), StatusCode::OK); - } - - #[actix_rt::test] - async fn test_no_origin_response() { - let mut cors = Cors::new() - .disable_preflight() - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::default().method(Method::GET).to_srv_request(); - let resp = test::call_service(&mut cors, req).await; - assert!(resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .is_none()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"https://www.example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - } - - #[actix_rt::test] - async fn test_response() { - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"*"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - #[allow(clippy::needless_collect)] - { - let headers = resp - .headers() - .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) - .unwrap() - .to_str() - .unwrap() - .split(',') - .map(|s| s.trim()) - .collect::>(); - - for h in exposed_headers { - assert!(headers.contains(&h.as_str())); - } - } - - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish() - .new_transform(fn_service(|req: ServiceRequest| { - ok(req.into_response( - HttpResponse::Ok().header(header::VARY, "Accept").finish(), - )) - })) - .await - .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"Accept, Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - let mut cors = Cors::new() - .disable_vary_header() - .allowed_origin("https://www.example.com") - .allowed_origin("https://www.google.com") - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; - - let origins_str = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap(); - - assert_eq!("https://www.example.com", origins_str); - } - - #[actix_rt::test] - async fn test_multiple_origins() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://example.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - - let req = TestRequest::with_header("Origin", "https://example.org") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - } - - #[actix_rt::test] - async fn test_multiple_origins_preflight() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - let req = TestRequest::with_header("Origin", "https://example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - - let req = TestRequest::with_header("Origin", "https://example.org") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - } - - #[actix_rt::test] - async fn test_allowed_origin_fn() { - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .allowed_origin_fn(|req| { - req.headers - .get(header::ORIGIN) - .map(HeaderValue::as_bytes) - .filter(|b| b.ends_with(b".unknown.com")) - .is_some() - }) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - { - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - "https://www.example.com", - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap() - ); - } - - { - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - Some(&b"https://www.unknown.com"[..]), - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .map(HeaderValue::as_bytes) - ); - } - } - - #[actix_rt::test] - async fn test_allowed_origin_fn_with_environment() { - let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .allowed_origin_fn(move |req| { - req.headers - .get(header::ORIGIN) - .map(HeaderValue::as_bytes) - .filter(|b| regex.is_match(b)) - .is_some() - }) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - { - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - "https://www.example.com", - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap() - ); - } - - { - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - Some(&b"https://www.unknown.com"[..]), - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .map(HeaderValue::as_bytes) - ); - } - } - - #[actix_rt::test] - async fn test_not_allowed_origin_fn() { - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .allowed_origin_fn(|req| { - req.headers - .get(header::ORIGIN) - .map(HeaderValue::as_bytes) - .filter(|b| b.ends_with(b".unknown.com")) - .is_some() - }) - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - - { - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - Some(&b"https://www.example.com"[..]), - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .map(HeaderValue::as_bytes) - ); - } - - { - let req = TestRequest::with_header("Origin", "https://www.known.com") - .method(Method::GET) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req).await; - - assert_eq!( - None, - resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - ); - } - } -} +#![forbid(unsafe_code)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(missing_docs, missing_debug_implementations)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] + +mod all_or_some; +mod builder; +mod error; +mod inner; +mod middleware; + +pub use all_or_some::AllOrSome; +pub use builder::{Cors, CorsFactory}; +pub use error::CorsError; +use inner::{cors, Inner, OriginFn}; +pub use middleware::CorsMiddleware; diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs new file mode 100644 index 000000000..4e9c2bafb --- /dev/null +++ b/actix-cors/src/middleware.rs @@ -0,0 +1,170 @@ +use std::{ + convert::TryInto, + rc::Rc, + task::{Context, Poll}, +}; + +use actix_web::{ + dev::{Service, ServiceRequest, ServiceResponse}, + error::{Error, Result}, + http::{ + header::{self, HeaderValue}, + Method, + }, + HttpResponse, +}; +use futures_util::future::{ok, Either, FutureExt as _, LocalBoxFuture, Ready}; + +use crate::Inner; + +/// Service wrapper for Cross-Origin Resource Sharing support. +/// +/// This struct contains the settings for CORS requests to be validated and for responses to +/// be generated. +#[derive(Debug, Clone)] +pub struct CorsMiddleware { + pub(crate) service: S, + pub(crate) inner: Rc, +} + +type CorsMiddlewareServiceFuture = Either< + Ready, Error>>, + LocalBoxFuture<'static, Result, Error>>, +>; + +impl Service for CorsMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = CorsMiddlewareServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + if self.inner.preflight && Method::OPTIONS == *req.method() { + if let Err(e) = self + .inner + .validate_origin(req.head()) + .and_then(|_| self.inner.validate_allowed_method(req.head())) + .and_then(|_| self.inner.validate_allowed_headers(req.head())) + { + return Either::Left(ok(req.error_response(e))); + } + + // allowed headers + let headers = if let Some(headers) = self.inner.headers.as_ref() { + Some( + headers + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..] + .try_into() + .unwrap(), + ) + } else if let Some(hdr) = + req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) + { + Some(hdr.clone()) + } else { + None + }; + + let res = HttpResponse::Ok() + .if_some(self.inner.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( + header::ACCESS_CONTROL_MAX_AGE, + format!("{}", max_age).as_str(), + ); + }) + .if_some(headers, |headers, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); + }) + .if_some( + self.inner.access_control_allow_origin(req.head()), + |origin, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }, + ) + .if_true(self.inner.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + &self + .inner + .methods + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .finish() + .into_body(); + + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } + } + + let inner = Rc::clone(&self.inner); + let has_origin = req.headers().contains_key(header::ORIGIN); + let fut = self.service.call(req); + + Either::Right( + async move { + let res = fut.await; + + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }; + + if let Some(ref expose) = inner.expose_headers { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + expose.as_str().try_into().unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = + if let Some(hdr) = res.headers_mut().get(header::VARY) { + let mut val: Vec = + Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + val.try_into().unwrap() + } else { + HeaderValue::from_static("Origin") + }; + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res + } + } + .boxed_local(), + ) + } + } +} diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs new file mode 100644 index 000000000..9ad761e38 --- /dev/null +++ b/actix-cors/tests/tests.rs @@ -0,0 +1,387 @@ +use actix_service::fn_service; +use actix_web::{ + dev::{ServiceRequest, Transform}, + http::{header, HeaderValue, Method, StatusCode}, + test::{self, TestRequest}, + HttpResponse, +}; +use futures_util::future::ok; +use regex::bytes::Regex; + +use actix_cors::Cors; + +#[actix_rt::test] +async fn test_not_allowed_origin_fn() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(|req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| b.ends_with(b".unknown.com")) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + { + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.example.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + } + + { + let req = TestRequest::with_header("Origin", "https://www.known.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + None, + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + ); + } +} + +#[actix_rt::test] +async fn test_allowed_origin_fn() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(|req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| b.ends_with(b".unknown.com")) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + "https://www.example.com", + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap() + ); + + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.unknown.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); +} + +#[actix_rt::test] +async fn test_allowed_origin_fn_with_environment() { + let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(move |req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| regex.is_match(b)) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + "https://www.example.com", + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap() + ); + + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.unknown.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); +} + +#[actix_rt::test] +async fn test_multiple_origins_preflight() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); +} + +#[actix_rt::test] +async fn test_multiple_origins() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); +} + +#[actix_rt::test] +async fn test_response() { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + #[allow(clippy::needless_collect)] + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(fn_service(|req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + let mut cors = Cors::new() + .disable_vary_header() + .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + + let origins_str = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!("https://www.example.com", origins_str); +} + +#[actix_rt::test] +async fn test_validate_origin() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); +} + +#[actix_rt::test] +async fn test_no_origin_response() { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::default().method(Method::GET).to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert!(resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://www.example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); +} + +#[actix_rt::test] +async fn validate_origin_allows_all_origins() { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = + TestRequest::with_header("Origin", "https://www.example.com").to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); +}