diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index e9c267050..3fb902758 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -8,13 +8,9 @@ ## 0.6.0 - 2022-02-25 - Update `actix-web` dependency to 4.0. -- Ensure that preflight responses contain a `Vary` header. [#224] - -[#224]: https://github.com/actix/actix-extras/pull/224 - ## 0.6.0-beta.10 - 2022-02-07 -- Ensure that preflight responses contain a Vary header. [#224] +- Ensure that preflight responses contain a `Vary` header. [#224] [#224]: https://github.com/actix/actix-extras/pull/224 diff --git a/actix-cors/src/error.rs b/actix-cors/src/error.rs index 7abd67044..9f4130f9e 100644 --- a/actix-cors/src/error.rs +++ b/actix-cors/src/error.rs @@ -2,8 +2,6 @@ use actix_web::{http::StatusCode, HttpResponse, ResponseError}; use derive_more::{Display, Error}; -use crate::inner::add_vary_header; - /// Errors that can occur when processing CORS guarded requests. #[derive(Debug, Clone, Display, Error)] #[non_exhaustive] @@ -47,8 +45,6 @@ impl ResponseError for CorsError { } fn error_response(&self) -> HttpResponse { - let mut res = HttpResponse::with_body(self.status_code(), self.to_string()); - add_vary_header(res.headers_mut()); - res.map_into_boxed_body() + HttpResponse::with_body(self.status_code(), self.to_string()).map_into_boxed_body() } } diff --git a/actix-cors/src/middleware.rs b/actix-cors/src/middleware.rs index 832e579c9..c65e9ee6f 100644 --- a/actix-cors/src/middleware.rs +++ b/actix-cors/src/middleware.rs @@ -31,6 +31,7 @@ pub struct CorsMiddleware { } impl CorsMiddleware { + /// Returns true if request is `OPTIONS` and contains an `Access-Control-Request-Method` header. fn is_request_preflight(req: &ServiceRequest) -> bool { // check request method is OPTIONS if req.method() != Method::OPTIONS { @@ -50,7 +51,15 @@ impl CorsMiddleware { true } - fn handle_preflight(inner: &Inner, req: ServiceRequest) -> ServiceResponse { + /// Validates preflight request headers against configuration and constructs preflight response. + /// + /// Checks: + /// - `Origin` header is acceptable; + /// - `Access-Control-Request-Method` header is acceptable; + /// - `Access-Control-Request-Headers` header is acceptable. + fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse { + let inner = Rc::clone(&self.inner); + if let Err(err) = inner .validate_origin(req.head()) .and_then(|_| inner.validate_allowed_method(req.head())) @@ -91,7 +100,10 @@ impl CorsMiddleware { } let mut res = res.finish(); - add_vary_header(res.headers_mut()); + + if inner.vary_header { + add_vary_header(res.headers_mut()); + } req.into_response(res) } @@ -162,31 +174,35 @@ where forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { + let origin = req.headers().get(header::ORIGIN); + + // handle preflight requests if self.inner.preflight && Self::is_request_preflight(&req) { - let inner = Rc::clone(&self.inner); - let res = Self::handle_preflight(&inner, req); + let res = self.handle_preflight(req); return ok(res.map_into_right_body()).boxed_local(); } - let origin = req.headers().get(header::ORIGIN).cloned(); - + // only check actual requests with a origin header if origin.is_some() { - // Only check requests with a origin header. if let Err(err) = self.inner.validate_origin(req.head()) { debug!("origin validation failed; inner service is not called"); - return ok(req.error_response(err).map_into_right_body()).boxed_local(); + let mut res = req.error_response(err); + + if self.inner.vary_header { + add_vary_header(res.headers_mut()); + } + + return ok(res.map_into_right_body()).boxed_local(); } } let inner = Rc::clone(&self.inner); let fut = self.service.call(req); - async move { + Box::pin(async move { let res = fut.await; - Ok(Self::augment_response(&inner, res?).map_into_left_body()) - } - .boxed_local() + }) } } @@ -216,7 +232,6 @@ mod tests { .allow_any_origin() .allowed_origin_fn(|origin, req_head| { assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap()); - req_head.headers().contains_key(header::DNT) }) .new_transform(test::ok_service())