From 99fe08f3326d8955d58a455091c12ae1fb2fe108 Mon Sep 17 00:00:00 2001 From: Roman Lakhtadyr Date: Fri, 25 Sep 2020 02:36:53 +0300 Subject: [PATCH] actix-cors: Create `allowed_origin_fn` builder method (#93) Co-authored-by: ArmorDarks --- actix-cors/CHANGES.md | 1 + actix-cors/src/lib.rs | 147 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index de1bc2fb2..8c38b9c91 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -1,6 +1,7 @@ # Changes ## Unreleased - 2020-xx-xx +* Implement `allowed_origin_fn` builder method. ## 0.3.0 - 2020-09-11 diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index 50ce49739..f1655384d 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -25,6 +25,13 @@ //! .wrap( //! Cors::new() // <- Construct CORS middleware builder //! .allowed_origin("https://www.rust-lang.org/") +//! .allowed_origin_fn(|req| { +//! req.headers +//! .get(http::header::ORIGIN) +//! .map(http::HeaderValue::as_bytes) +//! .filter(|b| b.ends_with(b".rust-lang.org")) +//! .is_some() +//! }) //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) @@ -46,6 +53,7 @@ use std::collections::HashSet; use std::convert::TryFrom; +use std::fmt; use std::iter::FromIterator; use std::rc::Rc; use std::task::{Context, Poll}; @@ -186,6 +194,7 @@ impl Cors { cors: Some(Inner { origins: AllOrSome::All, origins_str: None, + origins_fns: Vec::new(), methods: HashSet::new(), headers: AllOrSome::All, expose_hdrs: None, @@ -206,6 +215,7 @@ impl Cors { let inner = Inner { origins: AllOrSome::default(), origins_str: None, + origins_fns: Vec::new(), methods: HashSet::from_iter( vec![ Method::GET, @@ -248,6 +258,10 @@ impl Cors { /// If `send_wildcard` is not set, the client's `Origin` request header /// will be echoed back in the `Access-Control-Allow-Origin` response header. /// + /// If the origin of the request doesn't match any allowed origins and at least + /// one `allowed_origin_fn` function is set, these functions will be used + /// to determinate allowed origins. + /// /// Builder panics if supplied origin is not valid uri. pub fn allowed_origin(mut self, origin: &str) -> Cors { if let Some(cors) = cors(&mut self.cors, &self.error) { @@ -268,6 +282,21 @@ impl Cors { self } + /// Determinate allowed origins by processing requests which didn't match any origins + /// specified in the `allowed_origin`. + /// + /// The function will receive a `RequestHead` of each request, which can be used + /// to determine whether it should be allowed or not. + /// + /// If the function returns `true`, the client's `Origin` request header will be echoed + /// back into the `Access-Control-Allow-Origin` response header. + pub fn allowed_origin_fn(mut self, f: fn(req: &RequestHead) -> bool) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.origins_fns.push(OriginFn { f }); + } + self + } + /// Set a list of methods which allowed origins can perform. /// /// This is the `list of methods` in the @@ -567,10 +596,21 @@ pub struct CorsMiddleware { inner: Rc, } +struct OriginFn { + f: fn(req: &RequestHead) -> bool, +} + +impl fmt::Debug for OriginFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "origin_fn") + } +} + #[derive(Debug)] struct Inner { methods: HashSet, origins: AllOrSome>, + origins_fns: Vec, origins_str: Option, headers: AllOrSome>, expose_hdrs: Option, @@ -590,6 +630,13 @@ impl Inner { AllOrSome::Some(ref allowed_origins) => allowed_origins .get(origin) .map(|_| ()) + .or_else(|| { + if self.validate_origin_fns(req) { + Some(()) + } else { + None + } + }) .ok_or(CorsError::OriginNotAllowed), }; } @@ -602,6 +649,10 @@ impl Inner { } } + fn validate_origin_fns(&self, req: &RequestHead) -> bool { + self.origins_fns.iter().any(|origin_fn| (origin_fn.f)(req)) + } + fn access_control_allow_origin(&self, req: &RequestHead) -> Option { match self.origins { AllOrSome::All => { @@ -623,6 +674,8 @@ impl Inner { }) { Some(origin.clone()) + } else if self.validate_origin_fns(req) { + Some(req.headers().get(&header::ORIGIN).unwrap().clone()) } else { Some(self.origins_str.as_ref().unwrap().clone()) } @@ -1205,4 +1258,98 @@ mod tests { .as_bytes() ); } + + #[actix_rt::test] + async fn test_allowed_origin_fn() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(|req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| b.ends_with(b".unknown.com")) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + { + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + "https://www.example.com", + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap() + ); + } + + { + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.unknown.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + } + } + + #[actix_rt::test] + async fn test_not_allowed_origin_fn() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(|req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| b.ends_with(b".unknown.com")) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + { + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.example.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + } + + { + let req = TestRequest::with_header("Origin", "https://www.known.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + None, + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + ); + } + } }