use std::{collections::HashSet, rc::Rc}; use actix_utils::future::ok; use actix_web::{ body::{EitherBody, MessageBody}, dev::{forward_ready, Service, ServiceRequest, ServiceResponse}, http::{ header::{self, HeaderValue}, Method, }, Error, HttpResponse, Result, }; use futures_util::future::{FutureExt as _, LocalBoxFuture}; use log::debug; use crate::{ builder::intersperse_header_values, inner::{add_vary_header, header_value_try_into_method}, AllOrSome, CorsError, Inner, }; /// Service wrapper for Cross-Origin Resource Sharing support. /// /// This struct contains the settings for CORS requests to be validated and for responses to /// be generated. #[doc(hidden)] #[derive(Debug, Clone)] pub struct CorsMiddleware { pub(crate) service: S, pub(crate) inner: Rc, } impl CorsMiddleware { /// Returns true if request is `OPTIONS` and contains an `Access-Control-Request-Method` header. fn is_request_preflight(req: &ServiceRequest) -> bool { // check request method is OPTIONS if req.method() != Method::OPTIONS { return false; } // check follow-up request method is present and valid if req .headers() .get(header::ACCESS_CONTROL_REQUEST_METHOD) .and_then(header_value_try_into_method) .is_none() { return false; } true } /// Validates preflight request headers against configuration and constructs preflight response. /// /// Checks: /// - `Origin` header is acceptable; /// - `Access-Control-Request-Method` header is acceptable; /// - `Access-Control-Request-Headers` header is acceptable. fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse { let inner = Rc::clone(&self.inner); match inner.validate_origin(req.head()) { Ok(true) => {} Ok(false) => return req.error_response(CorsError::OriginNotAllowed), Err(err) => return req.error_response(err), }; if let Err(err) = inner .validate_allowed_method(req.head()) .and_then(|_| inner.validate_allowed_headers(req.head())) { return req.error_response(err); } let mut res = HttpResponse::Ok(); if let Some(origin) = inner.access_control_allow_origin(req.head()) { res.insert_header((header::ACCESS_CONTROL_ALLOW_ORIGIN, origin)); } if let Some(ref allowed_methods) = inner.allowed_methods_baked { res.insert_header(( header::ACCESS_CONTROL_ALLOW_METHODS, allowed_methods.clone(), )); } if let Some(ref headers) = inner.allowed_headers_baked { res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone())); } else if let Some(headers) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { // all headers allowed, return res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone())); } #[cfg(feature = "draft-local-network-access")] if inner.allow_local_network_access && req .headers() .contains_key("access-control-request-local-network") { res.insert_header(( header::HeaderName::from_static("access-control-allow-local-network"), HeaderValue::from_static("true"), )); } if inner.supports_credentials { res.insert_header(( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"), )); } if let Some(max_age) = inner.max_age { res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string())); } let mut res = res.finish(); if inner.vary_header { add_vary_header(res.headers_mut()); } req.into_response(res) } fn augment_response( inner: &Inner, origin_allowed: bool, mut res: ServiceResponse, ) -> ServiceResponse { if origin_allowed { if let Some(origin) = inner.access_control_allow_origin(res.request().head()) { res.headers_mut() .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); }; } if let Some(ref expose) = inner.expose_headers_baked { log::trace!("exposing selected headers: {:?}", expose); res.headers_mut() .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone()); } else if matches!(inner.expose_headers, AllOrSome::All) { // intersperse_header_values requires that argument is non-empty if !res.headers().is_empty() { // extract header names from request let expose_all_request_headers = res .headers() .keys() .map(|name| name.as_str()) .collect::>(); // create comma separated string of header names let expose_headers_value = intersperse_header_values(&expose_all_request_headers); log::trace!( "exposing all headers from request: {:?}", expose_headers_value ); // add header names to expose response header res.headers_mut() .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers_value); } } if inner.supports_credentials { res.headers_mut().insert( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"), ); } #[cfg(feature = "draft-local-network-access")] if inner.allow_local_network_access && res .request() .headers() .contains_key("access-control-request-local-network") { res.headers_mut().insert( header::HeaderName::from_static("access-control-allow-local-network"), HeaderValue::from_static("true"), ); } if inner.vary_header { add_vary_header(res.headers_mut()); } res } } impl Service for CorsMiddleware where S: Service, Error = Error>, S::Future: 'static, B: MessageBody + 'static, { type Response = ServiceResponse>; type Error = Error; type Future = LocalBoxFuture<'static, Result>, Error>>; forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let origin = req.headers().get(header::ORIGIN); // handle preflight requests if self.inner.preflight && Self::is_request_preflight(&req) { let res = self.handle_preflight(req); return ok(res.map_into_right_body()).boxed_local(); } // only check actual requests with a origin header let origin_allowed = match (origin, self.inner.validate_origin(req.head())) { (None, _) => false, (_, Ok(origin_allowed)) => origin_allowed, (_, Err(err)) => { debug!("origin validation failed; inner service is not called"); let mut res = req.error_response(err); if self.inner.vary_header { add_vary_header(res.headers_mut()); } return ok(res.map_into_right_body()).boxed_local(); } }; let inner = Rc::clone(&self.inner); let fut = self.service.call(req); Box::pin(async move { let res = fut.await; Ok(Self::augment_response(&inner, origin_allowed, res?).map_into_left_body()) }) } } #[cfg(test)] mod tests { use actix_web::{ dev::Transform, middleware::Compat, test::{self, TestRequest}, App, }; use super::*; use crate::Cors; #[test] fn compat_compat() { let _ = App::new().wrap(Compat::new(Cors::default())); } #[actix_web::test] async fn test_options_no_origin() { // Tests case where allowed_origins is All but there are validate functions to run incase. // In this case, origins are only allowed when the DNT header is sent. let cors = Cors::default() .allow_any_origin() .allowed_origin_fn(|origin, req_head| { assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap()); req_head.headers().contains_key(header::DNT) }) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header((header::ORIGIN, "http://example.com")) .to_srv_request(); let res = cors.call(req).await.unwrap(); assert_eq!( None, res.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); let req = TestRequest::get() .insert_header((header::ORIGIN, "http://example.com")) .insert_header((header::DNT, "1")) .to_srv_request(); let res = cors.call(req).await.unwrap(); assert_eq!( Some(&b"http://example.com"[..]), res.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } }