1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-27 09:12:57 +01:00

Add block_on_origin_mismatch option to middleware (#287)

Co-authored-by: CapableWeb <capableweb@domain.com>
Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
CapableWeb 2022-09-22 01:22:20 +02:00 committed by GitHub
parent 82a100d96c
commit 3b5682c860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 16 deletions

View File

@ -2,6 +2,9 @@
## Unreleased - 2022-xx-xx ## Unreleased - 2022-xx-xx
- Minimum supported Rust version (MSRV) is now 1.59 due to transitive `time` dependency. - Minimum supported Rust version (MSRV) is now 1.59 due to transitive `time` dependency.
- Add `Cors::block_on_origin_mismatch()` option for controlling if requests are pre-emptively rejected. [#286]
[#286]: https://github.com/actix/actix-extras/pull/286
## 0.6.2 - 2022-08-07 ## 0.6.2 - 2022-08-07

View File

@ -39,6 +39,8 @@ async fn main() -> std::io::Result<()> {
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
// set list of headers that are safe to expose // set list of headers that are safe to expose
.expose_headers(&[header::CONTENT_DISPOSITION]) .expose_headers(&[header::CONTENT_DISPOSITION])
// allow cURL/HTTPie from working without providing Origin headers
.block_on_origin_mismatch(false)
// set preflight cache TTL // set preflight cache TTL
.max_age(3600), .max_age(3600),
) )

View File

@ -102,6 +102,7 @@ impl Cors {
send_wildcard: false, send_wildcard: false,
supports_credentials: true, supports_credentials: true,
vary_header: true, vary_header: true,
block_on_origin_mismatch: true,
}; };
Cors { Cors {
@ -448,6 +449,24 @@ impl Cors {
self self
} }
/// Configures whether requests should be pre-emptively blocked on mismatched origin.
///
/// If `true`, a 400 Bad Request is returned immediately when a request fails origin validation.
///
/// If `false`, the request will be processed as normal but relevant CORS headers will not be
/// appended to the response. In this case, the browser is trusted to validate CORS headers and
/// and block requests based on pre-flight requests. Use this setting to allow cURL and other
/// non-browser HTTP clients to function as normal, no matter what `Origin` the request has.
///
/// Defaults to `true`.
pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.block_on_origin_mismatch = block
}
self
}
} }
impl Default for Cors { impl Default for Cors {
@ -474,6 +493,7 @@ impl Default for Cors {
send_wildcard: false, send_wildcard: false,
supports_credentials: false, supports_credentials: false,
vary_header: true, vary_header: true,
block_on_origin_mismatch: true,
}; };
Cors { Cors {

View File

@ -65,16 +65,19 @@ pub(crate) struct Inner {
pub(crate) send_wildcard: bool, pub(crate) send_wildcard: bool,
pub(crate) supports_credentials: bool, pub(crate) supports_credentials: bool,
pub(crate) vary_header: bool, pub(crate) vary_header: bool,
pub(crate) block_on_origin_mismatch: bool,
} }
static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::new); static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::new);
impl Inner { impl Inner {
pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { /// The bool returned in Ok(_) position indicates whether the `Access-Control-Allow-Origin`
/// header should be added to the response or not.
pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<bool, CorsError> {
// return early if all origins are allowed or get ref to allowed origins set // return early if all origins are allowed or get ref to allowed origins set
#[allow(clippy::mutable_key_type)] #[allow(clippy::mutable_key_type)]
let allowed_origins = match &self.allowed_origins { let allowed_origins = match &self.allowed_origins {
AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(()), AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(true),
AllOrSome::Some(allowed_origins) => allowed_origins, AllOrSome::Some(allowed_origins) => allowed_origins,
// only function origin validators are defined // only function origin validators are defined
_ => &EMPTY_ORIGIN_SET, _ => &EMPTY_ORIGIN_SET,
@ -85,9 +88,11 @@ impl Inner {
// origin header exists and is a string // origin header exists and is a string
Some(origin) => { Some(origin) => {
if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) { if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) {
Ok(()) Ok(true)
} else { } else if self.block_on_origin_mismatch {
Err(CorsError::OriginNotAllowed) Err(CorsError::OriginNotAllowed)
} else {
Ok(false)
} }
} }

View File

@ -16,7 +16,7 @@ use log::debug;
use crate::{ use crate::{
builder::intersperse_header_values, builder::intersperse_header_values,
inner::{add_vary_header, header_value_try_into_method}, inner::{add_vary_header, header_value_try_into_method},
AllOrSome, Inner, AllOrSome, CorsError, Inner,
}; };
/// Service wrapper for Cross-Origin Resource Sharing support. /// Service wrapper for Cross-Origin Resource Sharing support.
@ -60,9 +60,14 @@ impl<S> CorsMiddleware<S> {
fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse { fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse {
let inner = Rc::clone(&self.inner); let inner = Rc::clone(&self.inner);
match inner.validate_origin(req.head()) {
Ok(true) => {}
Ok(false) => return req.error_response(CorsError::OriginNotAllowed),
Err(err) => return req.error_response(err),
};
if let Err(err) = inner if let Err(err) = inner
.validate_origin(req.head()) .validate_allowed_method(req.head())
.and_then(|_| inner.validate_allowed_method(req.head()))
.and_then(|_| inner.validate_allowed_headers(req.head())) .and_then(|_| inner.validate_allowed_headers(req.head()))
{ {
return req.error_response(err); return req.error_response(err);
@ -108,11 +113,17 @@ impl<S> CorsMiddleware<S> {
req.into_response(res) req.into_response(res)
} }
fn augment_response<B>(inner: &Inner, mut res: ServiceResponse<B>) -> ServiceResponse<B> { fn augment_response<B>(
inner: &Inner,
origin_allowed: bool,
mut res: ServiceResponse<B>,
) -> ServiceResponse<B> {
if origin_allowed {
if let Some(origin) = inner.access_control_allow_origin(res.request().head()) { if let Some(origin) = inner.access_control_allow_origin(res.request().head()) {
res.headers_mut() res.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}; };
}
if let Some(ref expose) = inner.expose_headers_baked { if let Some(ref expose) = inner.expose_headers_baked {
log::trace!("exposing selected headers: {:?}", expose); log::trace!("exposing selected headers: {:?}", expose);
@ -182,8 +193,10 @@ where
} }
// only check actual requests with a origin header // only check actual requests with a origin header
if origin.is_some() { let origin_allowed = match (origin, self.inner.validate_origin(req.head())) {
if let Err(err) = self.inner.validate_origin(req.head()) { (None, _) => false,
(_, Ok(origin_allowed)) => origin_allowed,
(_, Err(err)) => {
debug!("origin validation failed; inner service is not called"); debug!("origin validation failed; inner service is not called");
let mut res = req.error_response(err); let mut res = req.error_response(err);
@ -193,14 +206,14 @@ where
return ok(res.map_into_right_body()).boxed_local(); return ok(res.map_into_right_body()).boxed_local();
} }
} };
let inner = Rc::clone(&self.inner); let inner = Rc::clone(&self.inner);
let fut = self.service.call(req); let fut = self.service.call(req);
Box::pin(async move { Box::pin(async move {
let res = fut.await; let res = fut.await;
Ok(Self::augment_response(&inner, res?).map_into_left_body()) Ok(Self::augment_response(&inner, origin_allowed, res?).map_into_left_body())
}) })
} }
} }

View File

@ -354,6 +354,54 @@ async fn test_validate_origin() {
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_web::test]
async fn test_blocks_mismatched_origin_by_default() {
let cors = Cors::default()
.allowed_origin("https://www.example.com")
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::get()
.insert_header(("Origin", "https://www.example.test"))
.to_srv_request();
let res = test::call_service(&cors, req).await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None);
assert!(res
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.is_none());
}
#[actix_web::test]
async fn test_mismatched_origin_block_turned_off() {
let cors = Cors::default()
.allow_any_method()
.allowed_origin("https://www.example.com")
.block_on_origin_mismatch(false)
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default()
.method(Method::OPTIONS)
.insert_header(("Origin", "https://wrong.com"))
.insert_header(("Access-Control-Request-Method", "POST"))
.to_srv_request();
let res = test::call_service(&cors, req).await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None);
let req = TestRequest::get()
.insert_header(("Origin", "https://wrong.com"))
.to_srv_request();
let res = test::call_service(&cors, req).await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None);
}
#[actix_web::test] #[actix_web::test]
async fn test_no_origin_response() { async fn test_no_origin_response() {
let cors = Cors::permissive() let cors = Cors::permissive()