From 0b1759510b1b563a93e82ea7a33ad40695986769 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Thu, 27 Oct 2022 00:01:54 +0100 Subject: [PATCH] add support for private network access cors header closes #294 --- actix-cors/CHANGES.md | 1 + actix-cors/Cargo.toml | 9 ++-- actix-cors/src/builder.rs | 36 ++++++++++--- actix-cors/src/inner.rs | 14 ++++++ actix-cors/src/lib.rs | 14 ++++-- actix-cors/src/middleware.rs | 25 +++++++++ actix-cors/tests/tests.rs | 98 +++++++++++++++++++++++++++++++++++- 7 files changed, 183 insertions(+), 14 deletions(-) diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index 8ad8ae4f1..ee3b24d5d 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -1,6 +1,7 @@ # Changes ## Unreleased - 2022-xx-xx +- Add `Cors::allow_private_network_access()` behind an unstable flag (`draft-private-network-access`). [#297] ## 0.6.3 - 2022-09-21 diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 7f1847eb5..22baadd87 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -12,9 +12,12 @@ repository = "https://github.com/actix/actix-extras.git" license = "MIT OR Apache-2.0" edition = "2018" -[lib] -name = "actix_cors" -path = "src/lib.rs" +[package.metadata.docs.rs] +rustdoc-args = ["--cfg", "docsrs"] +all-features = true + +[features] +draft-private-network-access = [] [dependencies] actix-utils = "3" diff --git a/actix-cors/src/builder.rs b/actix-cors/src/builder.rs index 4cf751e00..16d4a1578 100644 --- a/actix-cors/src/builder.rs +++ b/actix-cors/src/builder.rs @@ -101,6 +101,8 @@ impl Cors { preflight: true, send_wildcard: false, supports_credentials: true, + #[cfg(feature = "draft-private-network-access")] + allow_private_network_access: false, vary_header: true, block_on_origin_mismatch: true, }; @@ -370,7 +372,7 @@ impl Cors { /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol pub fn max_age(mut self, max_age: impl Into>) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.max_age = max_age.into() + cors.max_age = max_age.into(); } self @@ -389,7 +391,7 @@ impl Cors { /// Defaults to `false`. pub fn send_wildcard(mut self) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.send_wildcard = true + cors.send_wildcard = true; } self @@ -412,7 +414,27 @@ impl Cors { /// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol pub fn supports_credentials(mut self) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.supports_credentials = true + cors.supports_credentials = true; + } + + self + } + + /// Allow private network access. + /// + /// If true, injects the `Access-Control-Allow-Private-Network: true` header in responses if the + /// request contained the `Access-Control-Request-Private-Network: true` header. + /// + /// For more information on this behavior, see the draft [Private Network Access] spec. + /// + /// Defaults to `false`. + /// + /// [Private Network Access]: https://wicg.github.io/private-network-access + #[cfg(feature = "draft-private-network-access")] + #[cfg_attr(docsrs, doc(cfg(feature = "draft-private-network-access")))] + pub fn allow_private_network_access(mut self) -> Cors { + if let Some(cors) = cors(&mut self.inner, &self.error) { + cors.allow_private_network_access = true; } self @@ -430,7 +452,7 @@ impl Cors { /// By default, `Vary` header support is enabled. pub fn disable_vary_header(mut self) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.vary_header = false + cors.vary_header = false; } self @@ -444,7 +466,7 @@ impl Cors { /// By default *preflight* support is enabled. pub fn disable_preflight(mut self) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.preflight = false + cors.preflight = false; } self @@ -462,7 +484,7 @@ impl Cors { /// Defaults to `true`. pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors { if let Some(cors) = cors(&mut self.inner, &self.error) { - cors.block_on_origin_mismatch = block + cors.block_on_origin_mismatch = block; } self @@ -492,6 +514,8 @@ impl Default for Cors { preflight: true, send_wildcard: false, supports_credentials: false, + #[cfg(feature = "draft-private-network-access")] + allow_private_network_access: false, vary_header: true, block_on_origin_mismatch: true, }; diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs index 617bcd2ff..f5ab671cb 100644 --- a/actix-cors/src/inner.rs +++ b/actix-cors/src/inner.rs @@ -64,6 +64,8 @@ pub(crate) struct Inner { pub(crate) preflight: bool, pub(crate) send_wildcard: bool, pub(crate) supports_credentials: bool, + #[cfg(feature = "draft-private-network-access")] + pub(crate) allow_private_network_access: bool, pub(crate) vary_header: bool, pub(crate) block_on_origin_mismatch: bool, } @@ -219,8 +221,20 @@ pub(crate) fn add_vary_header(headers: &mut HeaderMap) { let mut val: Vec = Vec::with_capacity(hdr.len() + 71); val.extend(hdr.as_bytes()); val.extend(b", Origin, Access-Control-Request-Method, Access-Control-Request-Headers"); + + #[cfg(feature = "draft-private-network-access")] + val.extend(b", Access-Control-Allow-Private-Network"); + val.try_into().unwrap() } + + #[cfg(feature = "draft-private-network-access")] + None => HeaderValue::from_static( + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, \ + Access-Control-Allow-Private-Network", + ), + + #[cfg(not(feature = "draft-private-network-access"))] None => HeaderValue::from_static( "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ), diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index 4f343b729..b8cc3697b 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -1,11 +1,16 @@ //! Cross-Origin Resource Sharing (CORS) controls for Actix Web. //! -//! This middleware can be applied to both applications and resources. Once built, a -//! [`Cors`] builder can be used as an argument for Actix Web's `App::wrap()`, -//! `Scope::wrap()`, or `Resource::wrap()` methods. +//! This middleware can be applied to both applications and resources. Once built, a [`Cors`] +//! builder can be used as an argument for Actix Web's `App::wrap()`, `Scope::wrap()`, or +//! `Resource::wrap()` methods. //! //! This CORS middleware automatically handles `OPTIONS` preflight requests. //! +//! # Crate Features +//! - `draft-private-network-access`: ⚠️ Unstable. Adds opt-in support for the [Private Network +//! Access] spec extensions. This feature is unstable since it will follow any breaking changes in +//! the draft spec until it is finalized. +//! //! # Example //! ```no_run //! use actix_cors::Cors; @@ -40,12 +45,15 @@ //! Ok(()) //! } //! ``` +//! +//! [Private Network Access]: https://wicg.github.io/private-network-access #![forbid(unsafe_code)] #![deny(rust_2018_idioms, nonstandard_style)] #![warn(future_incompatible, missing_docs, missing_debug_implementations)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] +#![cfg_attr(docsrs, feature(doc_cfg))] mod all_or_some; mod builder; diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index 552c947d6..81055d439 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -93,6 +93,18 @@ impl CorsMiddleware { res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone())); } + #[cfg(feature = "draft-private-network-access")] + if inner.allow_private_network_access + && req + .headers() + .contains_key("access-control-request-private-network") + { + res.insert_header(( + header::HeaderName::from_static("access-control-allow-private-network"), + HeaderValue::from_static("true"), + )); + } + if inner.supports_credentials { res.insert_header(( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, @@ -162,6 +174,19 @@ impl CorsMiddleware { ); } + #[cfg(feature = "draft-private-network-access")] + if inner.allow_private_network_access + && res + .request() + .headers() + .contains_key("access-control-request-private-network") + { + res.headers_mut().insert( + header::HeaderName::from_static("access-control-allow-private-network"), + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { add_vary_header(res.headers_mut()); } diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs index 2637b4950..e08c76678 100644 --- a/actix-cors/tests/tests.rs +++ b/actix-cors/tests/tests.rs @@ -264,10 +264,16 @@ async fn test_response() { .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers().get(header::VARY).map(HeaderValue::as_bytes), Some(&b"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"[..]), ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers().get(header::VARY).map(HeaderValue::as_bytes), + Some(&b"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network"[..]), + ); #[allow(clippy::needless_collect)] { @@ -311,9 +317,18 @@ async fn test_response() { .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( - resp.headers().get(header::VARY).map(HeaderValue::as_bytes), - Some(&b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers"[..]), + resp.headers() + .get(header::VARY) + .map(HeaderValue::as_bytes) + .unwrap(), + b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers().get(header::VARY).map(HeaderValue::as_bytes).unwrap(), + b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network", ); let cors = Cors::default() @@ -463,6 +478,7 @@ async fn vary_header_on_all_handled_responses() { assert!(resp .headers() .contains_key(header::ACCESS_CONTROL_ALLOW_METHODS)); + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) @@ -471,6 +487,15 @@ async fn vary_header_on_all_handled_responses() { .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network", + ); // follow-up regular request let req = TestRequest::default() @@ -479,6 +504,7 @@ async fn vary_header_on_all_handled_responses() { .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) @@ -487,6 +513,15 @@ async fn vary_header_on_all_handled_responses() { .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network", + ); let cors = Cors::default() .allow_any_method() @@ -501,6 +536,7 @@ async fn vary_header_on_all_handled_responses() { .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) @@ -509,11 +545,21 @@ async fn vary_header_on_all_handled_responses() { .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network", + ); // regular request no origin let req = TestRequest::default().method(Method::PUT).to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); + #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) @@ -522,6 +568,15 @@ async fn vary_header_on_all_handled_responses() { .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); + #[cfg(feature = "draft-private-network-access")] + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Allow-Private-Network", + ); } #[actix_web::test] @@ -578,3 +633,42 @@ async fn expose_all_request_header_values() { assert!(cd_hdr.contains("content-disposition")); assert!(cd_hdr.contains("access-control-allow-origin")); } + +#[cfg(feature = "draft-private-network-access")] +#[actix_web::test] +async fn private_network_access() { + let cors = Cors::permissive() + .allowed_origin("https://public.site") + .allow_private_network_access() + .new_transform(fn_service(|req: ServiceRequest| async move { + let res = req.into_response( + HttpResponse::Ok() + .insert_header((header::CONTENT_DISPOSITION, "test disposition")) + .finish(), + ); + + Ok(res) + })) + .await + .unwrap(); + + let req = TestRequest::default() + .insert_header((header::ORIGIN, "https://public.site")) + .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) + .insert_header((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")) + .to_srv_request(); + let res = test::call_service(&cors, req).await; + assert!(res.headers().contains_key("access-control-allow-origin")); + + let req = TestRequest::default() + .insert_header((header::ORIGIN, "https://public.site")) + .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) + .insert_header((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")) + .insert_header(("Access-Control-Request-Private-Network", "true")) + .to_srv_request(); + let res = test::call_service(&cors, req).await; + assert!(res.headers().contains_key("access-control-allow-origin")); + assert!(res + .headers() + .contains_key("access-control-allow-private-network")); +}