1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-23 23:51:06 +01:00

conditionally add vary header to errors

This commit is contained in:
Rob Ede 2022-03-07 16:52:25 +00:00
parent 6fbe2eab94
commit b748e7e3a7
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
3 changed files with 30 additions and 23 deletions

View File

@ -8,13 +8,9 @@
## 0.6.0 - 2022-02-25 ## 0.6.0 - 2022-02-25
- Update `actix-web` dependency to 4.0. - Update `actix-web` dependency to 4.0.
- Ensure that preflight responses contain a `Vary` header. [#224]
[#224]: https://github.com/actix/actix-extras/pull/224
## 0.6.0-beta.10 - 2022-02-07 ## 0.6.0-beta.10 - 2022-02-07
- Ensure that preflight responses contain a Vary header. [#224] - Ensure that preflight responses contain a `Vary` header. [#224]
[#224]: https://github.com/actix/actix-extras/pull/224 [#224]: https://github.com/actix/actix-extras/pull/224

View File

@ -2,8 +2,6 @@ use actix_web::{http::StatusCode, HttpResponse, ResponseError};
use derive_more::{Display, Error}; use derive_more::{Display, Error};
use crate::inner::add_vary_header;
/// Errors that can occur when processing CORS guarded requests. /// Errors that can occur when processing CORS guarded requests.
#[derive(Debug, Clone, Display, Error)] #[derive(Debug, Clone, Display, Error)]
#[non_exhaustive] #[non_exhaustive]
@ -47,8 +45,6 @@ impl ResponseError for CorsError {
} }
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
let mut res = HttpResponse::with_body(self.status_code(), self.to_string()); HttpResponse::with_body(self.status_code(), self.to_string()).map_into_boxed_body()
add_vary_header(res.headers_mut());
res.map_into_boxed_body()
} }
} }

View File

@ -31,6 +31,7 @@ pub struct CorsMiddleware<S> {
} }
impl<S> CorsMiddleware<S> { impl<S> CorsMiddleware<S> {
/// Returns true if request is `OPTIONS` and contains an `Access-Control-Request-Method` header.
fn is_request_preflight(req: &ServiceRequest) -> bool { fn is_request_preflight(req: &ServiceRequest) -> bool {
// check request method is OPTIONS // check request method is OPTIONS
if req.method() != Method::OPTIONS { if req.method() != Method::OPTIONS {
@ -50,7 +51,15 @@ impl<S> CorsMiddleware<S> {
true true
} }
fn handle_preflight(inner: &Inner, req: ServiceRequest) -> ServiceResponse { /// Validates preflight request headers against configuration and constructs preflight response.
///
/// Checks:
/// - `Origin` header is acceptable;
/// - `Access-Control-Request-Method` header is acceptable;
/// - `Access-Control-Request-Headers` header is acceptable.
fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse {
let inner = Rc::clone(&self.inner);
if let Err(err) = inner if let Err(err) = inner
.validate_origin(req.head()) .validate_origin(req.head())
.and_then(|_| inner.validate_allowed_method(req.head())) .and_then(|_| inner.validate_allowed_method(req.head()))
@ -91,7 +100,10 @@ impl<S> CorsMiddleware<S> {
} }
let mut res = res.finish(); let mut res = res.finish();
if inner.vary_header {
add_vary_header(res.headers_mut()); add_vary_header(res.headers_mut());
}
req.into_response(res) req.into_response(res)
} }
@ -162,31 +174,35 @@ where
forward_ready!(service); forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let origin = req.headers().get(header::ORIGIN);
// handle preflight requests
if self.inner.preflight && Self::is_request_preflight(&req) { if self.inner.preflight && Self::is_request_preflight(&req) {
let inner = Rc::clone(&self.inner); let res = self.handle_preflight(req);
let res = Self::handle_preflight(&inner, req);
return ok(res.map_into_right_body()).boxed_local(); return ok(res.map_into_right_body()).boxed_local();
} }
let origin = req.headers().get(header::ORIGIN).cloned(); // only check actual requests with a origin header
if origin.is_some() { if origin.is_some() {
// Only check requests with a origin header.
if let Err(err) = self.inner.validate_origin(req.head()) { if let Err(err) = self.inner.validate_origin(req.head()) {
debug!("origin validation failed; inner service is not called"); debug!("origin validation failed; inner service is not called");
return ok(req.error_response(err).map_into_right_body()).boxed_local(); let mut res = req.error_response(err);
if self.inner.vary_header {
add_vary_header(res.headers_mut());
}
return ok(res.map_into_right_body()).boxed_local();
} }
} }
let inner = Rc::clone(&self.inner); let inner = Rc::clone(&self.inner);
let fut = self.service.call(req); let fut = self.service.call(req);
async move { Box::pin(async move {
let res = fut.await; let res = fut.await;
Ok(Self::augment_response(&inner, res?).map_into_left_body()) Ok(Self::augment_response(&inner, res?).map_into_left_body())
} })
.boxed_local()
} }
} }
@ -216,7 +232,6 @@ mod tests {
.allow_any_origin() .allow_any_origin()
.allowed_origin_fn(|origin, req_head| { .allowed_origin_fn(|origin, req_head| {
assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap()); assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap());
req_head.headers().contains_key(header::DNT) req_head.headers().contains_key(header::DNT)
}) })
.new_transform(test::ok_service()) .new_transform(test::ok_service())