From 3b5682c8609d6c521fb3a564730f97ca01160024 Mon Sep 17 00:00:00 2001 From: CapableWeb <112625454+CapableWeb@users.noreply.github.com> Date: Thu, 22 Sep 2022 01:22:20 +0200 Subject: [PATCH] Add block_on_origin_mismatch option to middleware (#287) Co-authored-by: CapableWeb Co-authored-by: Rob Ede --- actix-cors/CHANGES.md | 3 +++ actix-cors/examples/cors.rs | 2 ++ actix-cors/src/builder.rs | 20 +++++++++++++++ actix-cors/src/inner.rs | 13 +++++++--- actix-cors/src/middleware.rs | 37 ++++++++++++++++++--------- actix-cors/tests/tests.rs | 48 ++++++++++++++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 16 deletions(-) diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index e19b7847b..78c5d999f 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -2,6 +2,9 @@ ## Unreleased - 2022-xx-xx - Minimum supported Rust version (MSRV) is now 1.59 due to transitive `time` dependency. +- Add `Cors::block_on_origin_mismatch()` option for controlling if requests are pre-emptively rejected. [#286] + +[#286]: https://github.com/actix/actix-extras/pull/286 ## 0.6.2 - 2022-08-07 diff --git a/actix-cors/examples/cors.rs b/actix-cors/examples/cors.rs index b099f8fa7..91e0c972f 100644 --- a/actix-cors/examples/cors.rs +++ b/actix-cors/examples/cors.rs @@ -39,6 +39,8 @@ async fn main() -> std::io::Result<()> { .allowed_header(header::CONTENT_TYPE) // set list of headers that are safe to expose .expose_headers(&[header::CONTENT_DISPOSITION]) + // allow cURL/HTTPie from working without providing Origin headers + .block_on_origin_mismatch(false) // set preflight cache TTL .max_age(3600), ) diff --git a/actix-cors/src/builder.rs b/actix-cors/src/builder.rs index 7ffba714f..4cf751e00 100644 --- a/actix-cors/src/builder.rs +++ b/actix-cors/src/builder.rs @@ -102,6 +102,7 @@ impl Cors { send_wildcard: false, supports_credentials: true, vary_header: true, + block_on_origin_mismatch: true, }; Cors { @@ -448,6 +449,24 @@ impl Cors { self } + + /// Configures whether requests should be pre-emptively blocked on mismatched origin. + /// + /// If `true`, a 400 Bad Request is returned immediately when a request fails origin validation. + /// + /// If `false`, the request will be processed as normal but relevant CORS headers will not be + /// appended to the response. In this case, the browser is trusted to validate CORS headers and + /// and block requests based on pre-flight requests. Use this setting to allow cURL and other + /// non-browser HTTP clients to function as normal, no matter what `Origin` the request has. + /// + /// Defaults to `true`. + pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.block_on_origin_mismatch = block + } + + self + } } impl Default for Cors { @@ -474,6 +493,7 @@ impl Default for Cors { send_wildcard: false, supports_credentials: false, vary_header: true, + block_on_origin_mismatch: true, }; Cors { diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs index 523c1d8cc..617bcd2ff 100644 --- a/actix-cors/src/inner.rs +++ b/actix-cors/src/inner.rs @@ -65,16 +65,19 @@ pub(crate) struct Inner { pub(crate) send_wildcard: bool, pub(crate) supports_credentials: bool, pub(crate) vary_header: bool, + pub(crate) block_on_origin_mismatch: bool, } static EMPTY_ORIGIN_SET: Lazy> = Lazy::new(HashSet::new); impl Inner { - pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { + /// The bool returned in Ok(_) position indicates whether the `Access-Control-Allow-Origin` + /// header should be added to the response or not. + pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result { // return early if all origins are allowed or get ref to allowed origins set #[allow(clippy::mutable_key_type)] let allowed_origins = match &self.allowed_origins { - AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(()), + AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(true), AllOrSome::Some(allowed_origins) => allowed_origins, // only function origin validators are defined _ => &EMPTY_ORIGIN_SET, @@ -85,9 +88,11 @@ impl Inner { // origin header exists and is a string Some(origin) => { if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) { - Ok(()) - } else { + Ok(true) + } else if self.block_on_origin_mismatch { Err(CorsError::OriginNotAllowed) + } else { + Ok(false) } } diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index fb7b487bf..552c947d6 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -16,7 +16,7 @@ use log::debug; use crate::{ builder::intersperse_header_values, inner::{add_vary_header, header_value_try_into_method}, - AllOrSome, Inner, + AllOrSome, CorsError, Inner, }; /// Service wrapper for Cross-Origin Resource Sharing support. @@ -60,9 +60,14 @@ impl CorsMiddleware { fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse { let inner = Rc::clone(&self.inner); + match inner.validate_origin(req.head()) { + Ok(true) => {} + Ok(false) => return req.error_response(CorsError::OriginNotAllowed), + Err(err) => return req.error_response(err), + }; + if let Err(err) = inner - .validate_origin(req.head()) - .and_then(|_| inner.validate_allowed_method(req.head())) + .validate_allowed_method(req.head()) .and_then(|_| inner.validate_allowed_headers(req.head())) { return req.error_response(err); @@ -108,11 +113,17 @@ impl CorsMiddleware { req.into_response(res) } - fn augment_response(inner: &Inner, mut res: ServiceResponse) -> ServiceResponse { - if let Some(origin) = inner.access_control_allow_origin(res.request().head()) { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }; + fn augment_response( + inner: &Inner, + origin_allowed: bool, + mut res: ServiceResponse, + ) -> ServiceResponse { + if origin_allowed { + if let Some(origin) = inner.access_control_allow_origin(res.request().head()) { + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }; + } if let Some(ref expose) = inner.expose_headers_baked { log::trace!("exposing selected headers: {:?}", expose); @@ -182,8 +193,10 @@ where } // only check actual requests with a origin header - if origin.is_some() { - if let Err(err) = self.inner.validate_origin(req.head()) { + let origin_allowed = match (origin, self.inner.validate_origin(req.head())) { + (None, _) => false, + (_, Ok(origin_allowed)) => origin_allowed, + (_, Err(err)) => { debug!("origin validation failed; inner service is not called"); let mut res = req.error_response(err); @@ -193,14 +206,14 @@ where return ok(res.map_into_right_body()).boxed_local(); } - } + }; let inner = Rc::clone(&self.inner); let fut = self.service.call(req); Box::pin(async move { let res = fut.await; - Ok(Self::augment_response(&inner, res?).map_into_left_body()) + Ok(Self::augment_response(&inner, origin_allowed, res?).map_into_left_body()) }) } } diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs index 1f4f94a84..2637b4950 100644 --- a/actix-cors/tests/tests.rs +++ b/actix-cors/tests/tests.rs @@ -354,6 +354,54 @@ async fn test_validate_origin() { assert_eq!(resp.status(), StatusCode::OK); } +#[actix_web::test] +async fn test_blocks_mismatched_origin_by_default() { + let cors = Cors::default() + .allowed_origin("https://www.example.com") + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::get() + .insert_header(("Origin", "https://www.example.test")) + .to_srv_request(); + + let res = test::call_service(&cors, req).await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); + assert!(res + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) + .is_none()); +} + +#[actix_web::test] +async fn test_mismatched_origin_block_turned_off() { + let cors = Cors::default() + .allow_any_method() + .allowed_origin("https://www.example.com") + .block_on_origin_mismatch(false) + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::default() + .method(Method::OPTIONS) + .insert_header(("Origin", "https://wrong.com")) + .insert_header(("Access-Control-Request-Method", "POST")) + .to_srv_request(); + let res = test::call_service(&cors, req).await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); + + let req = TestRequest::get() + .insert_header(("Origin", "https://wrong.com")) + .to_srv_request(); + let res = test::call_service(&cors, req).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); +} + #[actix_web::test] async fn test_no_origin_response() { let cors = Cors::permissive()