1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-23 23:51:06 +01:00
actix-extras/actix-cors/src/middleware.rs

234 lines
7.8 KiB
Rust
Raw Normal View History

use std::{collections::HashSet, convert::TryInto, error::Error as StdError, rc::Rc};
2021-12-11 17:05:21 +01:00
use actix_utils::future::ok;
use actix_web::{
2021-12-11 17:05:21 +01:00
body::{EitherBody, MessageBody},
dev::{Service, ServiceRequest, ServiceResponse},
error::{Error, Result},
http::{
header::{self, HeaderValue},
Method,
},
HttpResponse,
};
2021-12-11 17:05:21 +01:00
use futures_util::future::{FutureExt as _, LocalBoxFuture};
2020-10-19 06:51:31 +02:00
use log::debug;
use crate::{builder::intersperse_header_values, AllOrSome, Inner};
/// Service wrapper for Cross-Origin Resource Sharing support.
///
/// This struct contains the settings for CORS requests to be validated and for responses to
/// be generated. It pairs the wrapped service with the CORS configuration; the configuration is
/// held behind an `Rc`, so clones of the middleware share one copy of it.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct CorsMiddleware<S> {
    /// The wrapped service that handles non-preflight requests.
    pub(crate) service: S,
    /// Shared CORS settings used for validation and response-header generation.
    pub(crate) inner: Rc<Inner>,
}
2020-10-19 06:51:31 +02:00
impl<S> CorsMiddleware<S> {
2021-06-27 08:02:38 +02:00
fn handle_preflight(inner: &Inner, req: ServiceRequest) -> ServiceResponse {
2020-10-19 06:51:31 +02:00
if let Err(err) = inner
.validate_origin(req.head())
.and_then(|_| inner.validate_allowed_method(req.head()))
.and_then(|_| inner.validate_allowed_headers(req.head()))
{
return req.error_response(err);
}
let mut res = HttpResponse::Ok();
if let Some(origin) = inner.access_control_allow_origin(req.head()) {
2021-03-21 23:50:26 +01:00
res.insert_header((header::ACCESS_CONTROL_ALLOW_ORIGIN, origin));
2020-10-19 06:51:31 +02:00
}
if let Some(ref allowed_methods) = inner.allowed_methods_baked {
2021-03-21 23:50:26 +01:00
res.insert_header((
2020-10-19 06:51:31 +02:00
header::ACCESS_CONTROL_ALLOW_METHODS,
allowed_methods.clone(),
2021-03-21 23:50:26 +01:00
));
2020-10-19 06:51:31 +02:00
}
if let Some(ref headers) = inner.allowed_headers_baked {
2021-03-21 23:50:26 +01:00
res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
2021-08-31 00:27:44 +02:00
} else if let Some(headers) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
2020-10-19 06:51:31 +02:00
// all headers allowed, return
2021-03-21 23:50:26 +01:00
res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
2020-10-19 06:51:31 +02:00
}
if inner.supports_credentials {
2021-03-21 23:50:26 +01:00
res.insert_header((
2020-10-19 06:51:31 +02:00
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
2021-03-21 23:50:26 +01:00
));
2020-10-19 06:51:31 +02:00
}
if let Some(max_age) = inner.max_age {
2021-03-21 23:50:26 +01:00
res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()));
2020-10-19 06:51:31 +02:00
}
let res = res.finish();
req.into_response(res)
}
2021-08-31 00:27:44 +02:00
fn augment_response<B>(inner: &Inner, mut res: ServiceResponse<B>) -> ServiceResponse<B> {
2020-10-19 06:51:31 +02:00
if let Some(origin) = inner.access_control_allow_origin(res.request().head()) {
res.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
};
if let Some(ref expose) = inner.expose_headers_baked {
log::trace!("exposing selected headers: {:?}", expose);
2020-10-19 06:51:31 +02:00
res.headers_mut()
.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
} else if matches!(inner.expose_headers, AllOrSome::All) {
// intersperse_header_values requires that argument is non-empty
if !res.request().headers().is_empty() {
// extract header names from request
let expose_all_request_headers = res
.request()
.headers()
.keys()
.into_iter()
.map(|name| name.as_str())
.collect::<HashSet<_>>();
// create comma separated string of header names
let expose_headers_value = intersperse_header_values(&expose_all_request_headers);
log::trace!(
"exposing all headers from request: {:?}",
expose_headers_value
);
// add header names to expose response header
res.headers_mut()
.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers_value);
}
2020-10-19 06:51:31 +02:00
}
if inner.supports_credentials {
res.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if inner.vary_header {
let value = match res.headers_mut().get(header::VARY) {
Some(hdr) => {
let mut val: Vec<u8> = Vec::with_capacity(hdr.len() + 8);
val.extend(hdr.as_bytes());
val.extend(b", Origin");
val.try_into().unwrap()
}
None => HeaderValue::from_static("Origin"),
};
res.headers_mut().insert(header::VARY, value);
}
res
}
}
2021-03-21 23:50:26 +01:00
impl<S, B> Service<ServiceRequest> for CorsMiddleware<S>
where
2021-03-21 23:50:26 +01:00
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
2021-06-27 08:02:38 +02:00
B: MessageBody + 'static,
B::Error: StdError,
{
2021-12-11 17:05:21 +01:00
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
2021-12-11 17:05:21 +01:00
type Future = LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>;
actix_service::forward_ready!(service);
2021-03-21 23:50:26 +01:00
fn call(&self, req: ServiceRequest) -> Self::Future {
2020-10-19 06:51:31 +02:00
if self.inner.preflight && req.method() == Method::OPTIONS {
let inner = Rc::clone(&self.inner);
let res = Self::handle_preflight(&inner, req);
2021-12-11 17:05:21 +01:00
ok(res.map_into_right_body()).boxed_local()
} else {
2020-10-19 06:51:31 +02:00
let origin = req.headers().get(header::ORIGIN).cloned();
if origin.is_some() {
// Only check requests with a origin header.
2020-10-19 06:51:31 +02:00
if let Err(err) = self.inner.validate_origin(req.head()) {
debug!("origin validation failed; inner service is not called");
2021-12-11 17:05:21 +01:00
return ok(req.error_response(err).map_into_right_body()).boxed_local();
}
}
let inner = Rc::clone(&self.inner);
let fut = self.service.call(req);
2021-12-11 17:05:21 +01:00
async move {
2020-10-19 06:51:31 +02:00
let res = fut.await;
if origin.is_some() {
Ok(Self::augment_response(&inner, res?))
2020-10-19 06:51:31 +02:00
} else {
res
}
2021-12-11 17:05:21 +01:00
.map(|res| res.map_into_left_body())
2020-10-19 06:51:31 +02:00
}
2021-12-11 17:05:21 +01:00
.boxed_local()
}
}
}
2020-10-19 06:51:31 +02:00
#[cfg(test)]
mod tests {
    use actix_web::{
        dev::Transform,
        test::{self, TestRequest},
    };

    use super::*;
    use crate::Cors;

    #[actix_rt::test]
    async fn test_options_no_origin() {
        // Any origin is allowed, but an origin validation function is also registered and still
        // gets run. It accepts the origin only when the request carries a DNT header.
        let cors = Cors::default()
            .allow_any_origin()
            .allowed_origin_fn(|origin, req_head| {
                assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap());

                req_head.headers().contains_key(header::DNT)
            })
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // Without DNT: no allow-origin header is expected on the response.
        let plain_req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .to_srv_request();
        let plain_res = cors.call(plain_req).await.unwrap();
        let plain_allow_origin = plain_res
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .map(HeaderValue::as_bytes);
        assert_eq!(None, plain_allow_origin);

        // With DNT: the origin is echoed back in the allow-origin header.
        let dnt_req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .insert_header((header::DNT, "1"))
            .to_srv_request();
        let dnt_res = cors.call(dnt_req).await.unwrap();
        let dnt_allow_origin = dnt_res
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .map(HeaderValue::as_bytes);
        assert_eq!(Some(&b"http://example.com"[..]), dnt_allow_origin);
    }
}