diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index 02a95458e..277e24331 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -1,6 +1,7 @@ # Changes ## Unreleased - 2020-xx-xx +* Improve `allowed_origin_fn` to allow using of closures. [#110] ## 0.4.0 - 2020-09-27 @@ -9,7 +10,7 @@ [#93]: https://github.com/actix/actix-extras/pull/93 [#106]: https://github.com/actix/actix-extras/pull/106 - +[#110]: https://github.com/actix/actix-extras/pull/110 ## 0.3.0 - 2020-09-11 * Update `actix-web` dependency to 3.0.0. diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 99b41865e..d78eeec1e 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -23,3 +23,4 @@ futures-util = { version = "0.3.4", default-features = false } [dev-dependencies] actix-rt = "1.1.1" +regex = "1.3.9" diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index d09657cda..726fba607 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -290,9 +290,14 @@ impl Cors { /// /// If the function returns `true`, the client's `Origin` request header will be echoed /// back into the `Access-Control-Allow-Origin` response header. - pub fn allowed_origin_fn(mut self, f: fn(req: &RequestHead) -> bool) -> Cors { + pub fn allowed_origin_fn(mut self, f: F) -> Cors + where + F: (Fn(&RequestHead) -> bool) + 'static, + { if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.origins_fns.push(OriginFn { f }); + cors.origins_fns.push(OriginFn { + boxed_fn: Box::new(f), + }); } self } @@ -597,7 +602,7 @@ pub struct CorsMiddleware { } struct OriginFn { - f: fn(req: &RequestHead) -> bool, + boxed_fn: Box bool>, } impl fmt::Debug for OriginFn { @@ -650,7 +655,9 @@ impl Inner { } fn validate_origin_fns(&self, req: &RequestHead) -> bool { - self.origins_fns.iter().any(|origin_fn| (origin_fn.f)(req)) + self.origins_fns + .iter() + .any(|origin_fn| (origin_fn.boxed_fn)(req)) } fn access_control_allow_origin(&self, req: &RequestHead) -> Option { @@ -882,6 +889,7 @@ mod tests { use std::convert::Infallible; use super::*; + use regex::bytes::Regex; #[actix_rt::test] async fn allowed_header_tryfrom() { @@ -1328,6 +1336,56 @@ mod tests { } } + #[actix_rt::test] + async fn test_allowed_origin_fn_with_environment() { + let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .allowed_origin_fn(move |req| { + req.headers + .get(header::ORIGIN) + .map(HeaderValue::as_bytes) + .filter(|b| regex.is_match(b)) + .is_some() + }) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + { + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + "https://www.example.com", + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap() + ); + } + + { + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + + assert_eq!( + Some(&b"https://www.unknown.com"[..]), + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .map(HeaderValue::as_bytes) + ); + } + } + #[actix_rt::test] async fn test_not_allowed_origin_fn() { let mut cors = Cors::new()