diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index 07d4863ef..2bc59e40a 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -6,14 +6,18 @@ * `CorsFactory` is removed. [#119] * The `impl Default` constructor is now overly-restrictive. [#119] * Added `Cors::permissive()` constructor that allows anything. [#119] -* Adds methods for each property to reset to a permissive state. (`allow_any_origin`, `expose_any_header`, etc.) [#119] +* Adds methods for each property to reset to a permissive state. (`allow_any_origin`, + `expose_any_header`, etc.) [#119] * Errors are now propagated with `Transform::InitError` instead of panicking. [#119] * Fixes bug where allowed origin functions are not called if `allowed_origins` is All. [#119] * `AllOrSome` is no longer public. [#119] +* Functions used for `allowed_origin_fn` now receive the Origin HeaderValue as the + first parameter. [#120] [#114]: https://github.com/actix/actix-extras/pull/114 [#118]: https://github.com/actix/actix-extras/pull/118 [#119]: https://github.com/actix/actix-extras/pull/119 +[#120]: https://github.com/actix/actix-extras/pull/120 ## 0.4.1 - 2020-10-07 diff --git a/actix-cors/examples/cors.rs b/actix-cors/examples/cors.rs index ca534b550..c2bbd11ac 100644 --- a/actix-cors/examples/cors.rs +++ b/actix-cors/examples/cors.rs @@ -14,15 +14,18 @@ async fn main() -> std::io::Result<()> { // add specific origin to allowed origin list .allowed_origin("http://project.local:8080") // allow any port on localhost - .allowed_origin_fn(|req_head| { + .allowed_origin_fn(|origin, _req_head| { + origin.as_bytes().starts_with(b"http://localhost") + + // manual alternative: // unwrapping is acceptable on the origin header since this function is // only called when it exists - req_head - .headers() - .get(header::ORIGIN) - .unwrap() - .as_bytes() - .starts_with(b"http://localhost") + // req_head + // .headers() + // .get(header::ORIGIN) + // .unwrap() + // .as_bytes() + // .starts_with(b"http://localhost") }) // set allowed methods list .allowed_methods(vec!["GET", "POST"]) diff --git a/actix-cors/src/builder.rs b/actix-cors/src/builder.rs index 20780e20a..d37d84011 100644 --- a/actix-cors/src/builder.rs +++ b/actix-cors/src/builder.rs @@ -13,7 +13,9 @@ use tinyvec::tiny_vec; use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn}; -pub(crate) fn cors<'a>( +/// Convenience for getting mut refs to inner. Cleaner than `Rc::get_mut`. +/// Additionally, always causes first error (if any) to be reported during initialization. +fn cors<'a>( inner: &'a mut Rc, err: &Option>, ) -> Option<&'a mut Inner> { @@ -178,7 +180,7 @@ impl Cors { /// into the `Access-Control-Allow-Origin` response header. pub fn allowed_origin_fn(mut self, f: F) -> Cors where - F: (Fn(&RequestHead) -> bool) + 'static, + F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static, { if let Some(cors) = cors(&mut self.inner, &self.error) { cors.allowed_origins_fns.push(OriginFn { diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs index cf031504f..ee81ff14b 100644 --- a/actix-cors/src/inner.rs +++ b/actix-cors/src/inner.rs @@ -15,13 +15,13 @@ use crate::{AllOrSome, CorsError}; #[derive(Clone)] pub(crate) struct OriginFn { - pub(crate) boxed_fn: Rc bool>, + pub(crate) boxed_fn: Rc bool>, } impl Default for OriginFn { /// Dummy default for use in tiny_vec. Do not use. fn default() -> Self { - let boxed_fn: Rc _> = Rc::new(|_req_head| false); + let boxed_fn: Rc _> = Rc::new(|_origin, _req_head| false); Self { boxed_fn } } } @@ -78,7 +78,9 @@ impl Inner { match req.headers().get(header::ORIGIN) { // origin header exists and is a string Some(origin) => { - if allowed_origins.contains(origin) || self.validate_origin_fns(req) { + if allowed_origins.contains(origin) + || self.validate_origin_fns(origin, req) + { Ok(()) } else { Err(CorsError::OriginNotAllowed) @@ -92,11 +94,11 @@ impl Inner { } } - /// Accepts origin if _ANY_ functions return true. - pub(crate) fn validate_origin_fns(&self, req: &RequestHead) -> bool { + /// Accepts origin if _ANY_ functions return true. Only called when Origin exists. + fn validate_origin_fns(&self, origin: &HeaderValue, req: &RequestHead) -> bool { self.allowed_origins_fns .iter() - .any(|origin_fn| (origin_fn.boxed_fn)(req)) + .any(|origin_fn| (origin_fn.boxed_fn)(origin, req)) } /// Only called if origin exists and always after it's validated. diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index 9d4738287..56c7f23fd 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -21,12 +21,8 @@ //! HttpServer::new(|| { //! let cors = Cors::default() //! .allowed_origin("https://www.rust-lang.org/") -//! .allowed_origin_fn(|req| { -//! req.headers -//! .get(http::header::ORIGIN) -//! .map(http::HeaderValue::as_bytes) -//! .filter(|b| b.ends_with(b".rust-lang.org")) -//! .is_some() +//! .allowed_origin_fn(|origin, _req_head| { +//! origin.as_bytes().ends_with(b".rust-lang.org") //! }) //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index a071717a9..8bd9f89b3 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -189,7 +189,11 @@ mod tests { let mut cors = Cors::default() .allow_any_origin() - .allowed_origin_fn(|req_head| req_head.headers().contains_key(header::DNT)) + .allowed_origin_fn(|origin, req_head| { + assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap()); + + req_head.headers().contains_key(header::DNT) + }) .new_transform(test::ok_service()) .await .unwrap(); diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs index e239f68b2..a8b2267e7 100644 --- a/actix-cors/tests/tests.rs +++ b/actix-cors/tests/tests.rs @@ -28,7 +28,9 @@ async fn test_wildcard_origin() { async fn test_not_allowed_origin_fn() { let mut cors = Cors::default() .allowed_origin("https://www.example.com") - .allowed_origin_fn(|req| { + .allowed_origin_fn(|origin, req| { + assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); + req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes) @@ -72,7 +74,9 @@ async fn test_not_allowed_origin_fn() { async fn test_allowed_origin_fn() { let mut cors = Cors::default() .allowed_origin("https://www.example.com") - .allowed_origin_fn(|req| { + .allowed_origin_fn(|origin, req| { + assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); + req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes) @@ -114,9 +118,12 @@ async fn test_allowed_origin_fn() { #[actix_rt::test] async fn test_allowed_origin_fn_with_environment() { let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); + let mut cors = Cors::default() .allowed_origin("https://www.example.com") - .allowed_origin_fn(move |req| { + .allowed_origin_fn(move |origin, req| { + assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); + req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes)