From 695800c9bda0d9f8e7e3b76a1675a94b89163ff9 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 7 Feb 2022 16:34:01 +0000 Subject: [PATCH] add vary header to all handled responses (#224) --- actix-cors/CHANGES.md | 3 ++ actix-cors/examples/cors.rs | 10 ++++- actix-cors/src/error.rs | 6 ++- actix-cors/src/inner.rs | 54 +++++++++++++++++++++++- actix-cors/src/middleware.rs | 31 ++++---------- actix-cors/tests/tests.rs | 81 +++++++++++++++++++++++++++++++++++- 6 files changed, 156 insertions(+), 29 deletions(-) diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index ca0a08f2b..122c6b41f 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -1,6 +1,9 @@ # Changes ## Unreleased - 2021-xx-xx +- Ensure that preflight responses contain a Vary header. [#224] + +[#224]: https://github.com/actix/actix-extras/pull/224 ## 0.6.0-beta.9 - 2022-02-07 diff --git a/actix-cors/examples/cors.rs b/actix-cors/examples/cors.rs index 343ddf425..b099f8fa7 100644 --- a/actix-cors/examples/cors.rs +++ b/actix-cors/examples/cors.rs @@ -1,12 +1,16 @@ use actix_cors::Cors; -use actix_web::{http::header, web, App, HttpServer}; +use actix_web::{http::header, middleware::Logger, web, App, HttpServer}; #[actix_web::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); + log::info!("starting HTTP server at http://localhost:8080"); + HttpServer::new(move || { App::new() + // `permissive` is a wide-open development config + // .wrap(Cors::permissive()) .wrap( // default settings are overly restrictive to reduce chance of // misconfiguration leading to security concerns @@ -38,9 +42,11 @@ async fn main() -> std::io::Result<()> { // set preflight cache TTL .max_age(3600), ) + .wrap(Logger::default()) .default_service(web::to(|| async { "Hello, cross-origin world!" })) }) - .bind("127.0.0.1:8080")? + .workers(1) + .bind(("127.0.0.1", 8080))? .run() .await } diff --git a/actix-cors/src/error.rs b/actix-cors/src/error.rs index 5da672c81..7abd67044 100644 --- a/actix-cors/src/error.rs +++ b/actix-cors/src/error.rs @@ -2,6 +2,8 @@ use actix_web::{http::StatusCode, HttpResponse, ResponseError}; use derive_more::{Display, Error}; +use crate::inner::add_vary_header; + /// Errors that can occur when processing CORS guarded requests. #[derive(Debug, Clone, Display, Error)] #[non_exhaustive] @@ -45,6 +47,8 @@ impl ResponseError for CorsError { } fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::BAD_REQUEST, self.to_string()).map_into_boxed_body() + let mut res = HttpResponse::with_body(self.status_code(), self.to_string()); + add_vary_header(res.headers_mut()); + res.map_into_boxed_body() } } diff --git a/actix-cors/src/inner.rs b/actix-cors/src/inner.rs index 55ebf7c31..fb5756750 100644 --- a/actix-cors/src/inner.rs +++ b/actix-cors/src/inner.rs @@ -4,7 +4,7 @@ use actix_web::{ dev::RequestHead, error::Result, http::{ - header::{self, HeaderName, HeaderValue}, + header::{self, HeaderMap, HeaderName, HeaderValue}, Method, }, }; @@ -199,6 +199,25 @@ impl Inner { } } +/// Add CORS related request headers to response's Vary header. +/// +/// See . +pub(crate) fn add_vary_header(headers: &mut HeaderMap) { + let value = match headers.get(header::VARY) { + Some(hdr) => { + let mut val: Vec = Vec::with_capacity(hdr.len() + 71); + val.extend(hdr.as_bytes()); + val.extend(b", Origin, Access-Control-Request-Method, Access-Control-Request-Headers"); + val.try_into().unwrap() + } + None => HeaderValue::from_static( + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ), + }; + + headers.insert(header::VARY, value); +} + #[cfg(test)] mod test { use std::rc::Rc; @@ -327,4 +346,37 @@ mod test { let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } + + #[actix_web::test] + async fn allow_fn_origin_equals_head_origin() { + let cors = Cors::default() + .allowed_origin_fn(|origin, head| { + let head_origin = head + .headers() + .get(header::ORIGIN) + .expect("unwrapping origin header should never fail in allowed_origin_fn"); + assert!(origin == head_origin); + true + }) + .allow_any_method() + .allow_any_header() + .new_transform(test::simple_service(StatusCode::NO_CONTENT)) + .await + .unwrap(); + + let req = TestRequest::default() + .method(Method::OPTIONS) + .insert_header(("Origin", "https://www.example.com")) + .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) + .to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::default() + .method(Method::GET) + .insert_header(("Origin", "https://www.example.com")) + .to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + } } diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index 1f01d997e..f6030db9e 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, convert::TryInto, rc::Rc}; +use std::{collections::HashSet, rc::Rc}; use actix_utils::future::ok; use actix_web::{ @@ -13,7 +13,7 @@ use actix_web::{ use futures_util::future::{FutureExt as _, LocalBoxFuture}; use log::debug; -use crate::{builder::intersperse_header_values, AllOrSome, Inner}; +use crate::{builder::intersperse_header_values, inner::add_vary_header, AllOrSome, Inner}; /// Service wrapper for Cross-Origin Resource Sharing support. /// @@ -67,7 +67,9 @@ impl CorsMiddleware { res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string())); } - let res = res.finish(); + let mut res = res.finish(); + add_vary_header(res.headers_mut()); + req.into_response(res) } @@ -116,21 +118,7 @@ impl CorsMiddleware { } if inner.vary_header { - let value = match res.headers_mut().get(header::VARY) { - Some(hdr) => { - let mut val: Vec = Vec::with_capacity(hdr.len() + 71); - val.extend(hdr.as_bytes()); - val.extend( - b", Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - ); - val.try_into().unwrap() - } - None => HeaderValue::from_static( - "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", - ), - }; - - res.headers_mut().insert(header::VARY, value); + add_vary_header(res.headers_mut()); } res @@ -172,12 +160,7 @@ where async move { let res = fut.await; - if origin.is_some() { - Ok(Self::augment_response(&inner, res?)) - } else { - res.map_err(Into::into) - } - .map(|res| res.map_into_left_body()) + Ok(Self::augment_response(&inner, res?).map_into_left_body()) } .boxed_local() } diff --git a/actix-cors/tests/tests.rs b/actix-cors/tests/tests.rs index 3b1a52664..8ff8b098e 100644 --- a/actix-cors/tests/tests.rs +++ b/actix-cors/tests/tests.rs @@ -314,8 +314,8 @@ async fn test_response() { .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( + resp.headers().get(header::VARY).map(HeaderValue::as_bytes), Some(&b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers"[..]), - resp.headers().get(header::VARY).map(HeaderValue::as_bytes) ); let cors = Cors::default() @@ -399,6 +399,85 @@ async fn validate_origin_allows_all_origins() { assert_eq!(resp.status(), StatusCode::OK); } +#[actix_web::test] +async fn vary_header_on_all_handled_responses() { + let cors = Cors::permissive() + .new_transform(test::ok_service()) + .await + .unwrap(); + + // preflight request + let req = TestRequest::default() + .method(Method::OPTIONS) + .insert_header((header::ORIGIN, "https://www.example.com")) + .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET")) + .to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp + .headers() + .contains_key(header::ACCESS_CONTROL_ALLOW_METHODS)); + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ); + + // follow-up regular request + let req = TestRequest::default() + .method(Method::PUT) + .insert_header((header::ORIGIN, "https://www.example.com")) + .to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ); + + let cors = Cors::default() + .allow_any_method() + .new_transform(test::ok_service()) + .await + .unwrap(); + + // regular request bad origin + let req = TestRequest::default() + .method(Method::PUT) + .insert_header((header::ORIGIN, "https://www.example.com")) + .to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ); + + // regular request no origin + let req = TestRequest::default().method(Method::PUT).to_srv_request(); + let resp = test::call_service(&cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers() + .get(header::VARY) + .expect("response should have Vary header") + .to_str() + .unwrap(), + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", + ); +} + #[actix_web::test] async fn test_allow_any_origin_any_method_any_header() { let cors = Cors::default()