mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-30 10:32:55 +01:00
CORS builder rework (#119)
This commit is contained in:
parent
c8e641d4b6
commit
6084c47810
@ -2,10 +2,18 @@
|
|||||||
|
|
||||||
## Unreleased - 2020-xx-xx
|
## Unreleased - 2020-xx-xx
|
||||||
* Disallow `*` in `Cors::allowed_origin` by panicking. [#114].
|
* Disallow `*` in `Cors::allowed_origin` by panicking. [#114].
|
||||||
* Hide `CorsMiddleware` from rustdocs. [#118].
|
* Hide `CorsMiddleware` from docs. [#118].
|
||||||
|
* `CorsFactory` is removed. [#119]
|
||||||
|
* The `impl Default` constructor is now overly-restrictive. [#119]
|
||||||
|
* Added `Cors::permissive()` constructor that allows anything. [#119]
|
||||||
|
* Adds methods for each property to reset to a permissive state. (`allow_any_origin`, `expose_any_header`, etc.) [#119]
|
||||||
|
* Errors are now propagated with `Transform::InitError` instead of panicking. [#119]
|
||||||
|
* Fixes bug where allowed origin functions are not called if `allowed_origins` is All. [#119]
|
||||||
|
* `AllOrSome` is no longer public. [#119]
|
||||||
|
|
||||||
[#114]: https://github.com/actix/actix-extras/pull/114
|
[#114]: https://github.com/actix/actix-extras/pull/114
|
||||||
[#118]: https://github.com/actix/actix-extras/pull/118
|
[#118]: https://github.com/actix/actix-extras/pull/118
|
||||||
|
[#119]: https://github.com/actix/actix-extras/pull/119
|
||||||
|
|
||||||
|
|
||||||
## 0.4.1 - 2020-10-07
|
## 0.4.1 - 2020-10-07
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "actix-cors"
|
name = "actix-cors"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
|
authors = [
|
||||||
|
"Nikolay Kim <fafhrd91@gmail.com>",
|
||||||
|
"Rob Ede <robjtede@icloud.com>"
|
||||||
|
]
|
||||||
description = "Cross-Origin Resource Sharing (CORS) controls for Actix Web"
|
description = "Cross-Origin Resource Sharing (CORS) controls for Actix Web"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
keywords = ["actix", "cors", "web", "security", "crossorigin"]
|
keywords = ["actix", "cors", "web", "security", "crossorigin"]
|
||||||
@ -19,8 +22,12 @@ path = "src/lib.rs"
|
|||||||
actix-web = { version = "3.0.0", default-features = false }
|
actix-web = { version = "3.0.0", default-features = false }
|
||||||
derive_more = "0.99.2"
|
derive_more = "0.99.2"
|
||||||
futures-util = { version = "0.3.4", default-features = false }
|
futures-util = { version = "0.3.4", default-features = false }
|
||||||
|
log = "0.4"
|
||||||
|
once_cell = "1"
|
||||||
|
tinyvec = "1.0.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
actix-service = "1.0.6"
|
actix-service = "1"
|
||||||
actix-rt = "1.1.1"
|
actix-rt = "1"
|
||||||
regex = "1.3.9"
|
pretty_env_logger = "0.4"
|
||||||
|
regex = "1.4"
|
||||||
|
43
actix-cors/examples/cors.rs
Normal file
43
actix-cors/examples/cors.rs
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
use actix_cors::Cors;
|
||||||
|
use actix_web::{http::header, web, App, HttpServer};
|
||||||
|
|
||||||
|
#[actix_web::main]
|
||||||
|
async fn main() -> std::io::Result<()> {
|
||||||
|
pretty_env_logger::init();
|
||||||
|
|
||||||
|
HttpServer::new(move || {
|
||||||
|
App::new()
|
||||||
|
.wrap(
|
||||||
|
// default settings are overly restrictive to reduce chance of
|
||||||
|
// misconfiguration leading to security concerns
|
||||||
|
Cors::default()
|
||||||
|
// add specific origin to allowed origin list
|
||||||
|
.allowed_origin("http://project.local:8080")
|
||||||
|
// allow any port on localhost
|
||||||
|
.allowed_origin_fn(|req_head| {
|
||||||
|
// unwrapping is acceptable on the origin header since this function is
|
||||||
|
// only called when it exists
|
||||||
|
req_head
|
||||||
|
.headers()
|
||||||
|
.get(header::ORIGIN)
|
||||||
|
.unwrap()
|
||||||
|
.as_bytes()
|
||||||
|
.starts_with(b"http://localhost")
|
||||||
|
})
|
||||||
|
// set allowed methods list
|
||||||
|
.allowed_methods(vec!["GET", "POST"])
|
||||||
|
// set allowed request header list
|
||||||
|
.allowed_headers(&[header::AUTHORIZATION, header::ACCEPT])
|
||||||
|
// add header to allowed list
|
||||||
|
.allowed_header(header::CONTENT_TYPE)
|
||||||
|
// set list of headers that are safe to expose
|
||||||
|
.expose_headers(&[header::CONTENT_DISPOSITION])
|
||||||
|
// set CORS rules ttl
|
||||||
|
.max_age(3600),
|
||||||
|
)
|
||||||
|
.default_service(web::to(|| async { "Hello world!" }))
|
||||||
|
})
|
||||||
|
.bind("127.0.0.1:8080")?
|
||||||
|
.run()
|
||||||
|
.await
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
/// An enum signifying that some of type `T` is allowed, or `All` (anything is allowed).
|
/// An enum signifying that some of type `T` is allowed, or `All` (anything is allowed).
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
pub enum AllOrSome<T> {
|
pub enum AllOrSome<T> {
|
||||||
/// Everything is allowed. Usually equivalent to the `*` value.
|
/// Everything is allowed. Usually equivalent to the `*` value.
|
||||||
All,
|
All,
|
||||||
@ -22,17 +22,26 @@ impl<T> AllOrSome<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns whether this is a `Some` variant.
|
/// Returns whether this is a `Some` variant.
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn is_some(&self) -> bool {
|
pub fn is_some(&self) -> bool {
|
||||||
!self.is_all()
|
!self.is_all()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns &T.
|
/// Provides a shared reference to `T` if variant is `Some`.
|
||||||
pub fn as_ref(&self) -> Option<&T> {
|
pub fn as_ref(&self) -> Option<&T> {
|
||||||
match *self {
|
match *self {
|
||||||
AllOrSome::All => None,
|
AllOrSome::All => None,
|
||||||
AllOrSome::Some(ref t) => Some(t),
|
AllOrSome::Some(ref t) => Some(t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Provides a mutable reference to `T` if variant is `Some`.
|
||||||
|
pub fn as_mut(&mut self) -> Option<&mut T> {
|
||||||
|
match *self {
|
||||||
|
AllOrSome::All => None,
|
||||||
|
AllOrSome::Some(ref mut t) => Some(t),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -3,99 +3,117 @@ use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc};
|
|||||||
use actix_web::{
|
use actix_web::{
|
||||||
dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
|
dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
|
||||||
error::{Error, Result},
|
error::{Error, Result},
|
||||||
http::{self, header::HeaderName, Error as HttpError, Method, Uri},
|
http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri},
|
||||||
|
Either,
|
||||||
};
|
};
|
||||||
use futures_util::future::{ok, Ready};
|
use futures_util::future::{self, Ready};
|
||||||
|
use log::error;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tinyvec::tiny_vec;
|
||||||
|
|
||||||
use crate::{cors, AllOrSome, CorsMiddleware, Inner, OriginFn};
|
use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
|
||||||
|
|
||||||
|
pub(crate) fn cors<'a>(
|
||||||
|
inner: &'a mut Rc<Inner>,
|
||||||
|
err: &Option<Either<http::Error, CorsError>>,
|
||||||
|
) -> Option<&'a mut Inner> {
|
||||||
|
if err.is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Rc::get_mut(inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
|
||||||
|
HashSet::from_iter(vec![
|
||||||
|
Method::GET,
|
||||||
|
Method::POST,
|
||||||
|
Method::PUT,
|
||||||
|
Method::DELETE,
|
||||||
|
Method::HEAD,
|
||||||
|
Method::OPTIONS,
|
||||||
|
Method::CONNECT,
|
||||||
|
Method::PATCH,
|
||||||
|
Method::TRACE,
|
||||||
|
])
|
||||||
|
});
|
||||||
|
|
||||||
/// Builder for CORS middleware.
|
/// Builder for CORS middleware.
|
||||||
///
|
///
|
||||||
/// To construct a CORS middleware:
|
/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder.
|
||||||
|
/// Then use any of the builder methods to customize CORS behavior.
|
||||||
///
|
///
|
||||||
/// 1. Call [`Cors::new()`] to start building.
|
/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing
|
||||||
/// 2. Use any of the builder methods to customize CORS behavior.
|
/// all origins and headers, etc. **The permissive constructor should not be used in production.**
|
||||||
/// 3. Call [`Cors::finish()`] to build the middleware.
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled
|
||||||
|
/// in Actix Web (using `env_logger` or other crate that exposes logs from the `log` crate), error
|
||||||
|
/// messages will outline what is wrong with the CORS configuration in the server logs and the
|
||||||
|
/// server will fail to start up or serve requests.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use actix_cors::{Cors, CorsFactory};
|
/// use actix_cors::Cors;
|
||||||
/// use actix_web::http::header;
|
/// use actix_web::http::header;
|
||||||
///
|
///
|
||||||
/// let cors = Cors::new()
|
/// let cors = Cors::default()
|
||||||
/// .allowed_origin("https://www.rust-lang.org")
|
/// .allowed_origin("https://www.rust-lang.org")
|
||||||
/// .allowed_methods(vec!["GET", "POST"])
|
/// .allowed_methods(vec!["GET", "POST"])
|
||||||
/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
|
/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
|
||||||
/// .allowed_header(header::CONTENT_TYPE)
|
/// .allowed_header(header::CONTENT_TYPE)
|
||||||
/// .max_age(3600)
|
/// .max_age(3600);
|
||||||
/// .finish();
|
|
||||||
///
|
///
|
||||||
/// // `cors` can now be used in `App::wrap`.
|
/// // `cors` can now be used in `App::wrap`.
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug)]
|
||||||
pub struct Cors {
|
pub struct Cors {
|
||||||
cors: Option<Inner>,
|
inner: Rc<Inner>,
|
||||||
methods: bool,
|
error: Option<Either<http::Error, CorsError>>,
|
||||||
error: Option<http::Error>,
|
|
||||||
expose_headers: HashSet<HeaderName>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cors {
|
impl Cors {
|
||||||
/// Return a new builder.
|
/// A very permissive set of default for quick development. Not recommended for production use.
|
||||||
pub fn new() -> Cors {
|
///
|
||||||
Cors {
|
/// *All* origins, methods, request headers and exposed headers allowed. Credentials supported.
|
||||||
cors: Some(Inner {
|
/// Max age 1 hour. Does not send wildcard.
|
||||||
origins: AllOrSome::All,
|
pub fn permissive() -> Self {
|
||||||
origins_str: None,
|
|
||||||
origins_fns: Vec::new(),
|
|
||||||
methods: HashSet::new(),
|
|
||||||
headers: AllOrSome::All,
|
|
||||||
expose_headers: None,
|
|
||||||
max_age: None,
|
|
||||||
preflight: true,
|
|
||||||
send_wildcard: false,
|
|
||||||
supports_credentials: false,
|
|
||||||
vary_header: true,
|
|
||||||
}),
|
|
||||||
methods: false,
|
|
||||||
error: None,
|
|
||||||
expose_headers: HashSet::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build a CORS middleware with default settings.
|
|
||||||
pub fn default() -> CorsFactory {
|
|
||||||
let inner = Inner {
|
let inner = Inner {
|
||||||
origins: AllOrSome::default(),
|
allowed_origins: AllOrSome::All,
|
||||||
origins_str: None,
|
allowed_origins_fns: tiny_vec![],
|
||||||
origins_fns: Vec::new(),
|
|
||||||
methods: HashSet::from_iter(
|
allowed_methods: ALL_METHODS_SET.clone(),
|
||||||
vec![
|
allowed_methods_baked: None,
|
||||||
Method::GET,
|
|
||||||
Method::HEAD,
|
allowed_headers: AllOrSome::All,
|
||||||
Method::POST,
|
allowed_headers_baked: None,
|
||||||
Method::OPTIONS,
|
|
||||||
Method::PUT,
|
expose_headers: AllOrSome::All,
|
||||||
Method::PATCH,
|
expose_headers_baked: None,
|
||||||
Method::DELETE,
|
max_age: Some(3600),
|
||||||
]
|
|
||||||
.into_iter(),
|
|
||||||
),
|
|
||||||
headers: AllOrSome::All,
|
|
||||||
expose_headers: None,
|
|
||||||
max_age: None,
|
|
||||||
preflight: true,
|
preflight: true,
|
||||||
send_wildcard: false,
|
send_wildcard: false,
|
||||||
supports_credentials: false,
|
supports_credentials: true,
|
||||||
vary_header: true,
|
vary_header: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
CorsFactory {
|
Cors {
|
||||||
inner: Rc::new(inner),
|
inner: Rc::new(inner),
|
||||||
|
error: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets allowed origin list to a state where any origin is accepted.
|
||||||
|
///
|
||||||
|
/// See [`Cors::allowed_origin`] for more info on allowed origins.
|
||||||
|
pub fn allow_any_origin(mut self) -> Cors {
|
||||||
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
|
cors.allowed_origins = AllOrSome::All;
|
||||||
|
}
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Add an origin that is allowed to make requests.
|
/// Add an origin that is allowed to make requests.
|
||||||
///
|
///
|
||||||
/// By default, requests from all origins are accepted by CORS logic. This method allows to
|
/// By default, requests from all origins are accepted by CORS logic. This method allows to
|
||||||
@ -115,32 +133,34 @@ impl Cors {
|
|||||||
/// `allowed_origin_fn` function is set, these functions will be used to determinate
|
/// `allowed_origin_fn` function is set, these functions will be used to determinate
|
||||||
/// allowed origins.
|
/// allowed origins.
|
||||||
///
|
///
|
||||||
/// # Panics
|
/// # Initialization Errors
|
||||||
///
|
/// - If supplied origin is not valid uri
|
||||||
/// * If supplied origin is not valid uri, or
|
/// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
|
||||||
/// * If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
|
|
||||||
///
|
///
|
||||||
/// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
|
/// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
|
||||||
pub fn allowed_origin(mut self, origin: &str) -> Cors {
|
pub fn allowed_origin(mut self, origin: &str) -> Cors {
|
||||||
assert!(
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
origin != "*",
|
|
||||||
"Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`."
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
|
||||||
match TryInto::<Uri>::try_into(origin) {
|
match TryInto::<Uri>::try_into(origin) {
|
||||||
|
Ok(_) if origin == "*" => {
|
||||||
|
error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
|
||||||
|
self.error = Some(Either::B(CorsError::WildcardOrigin));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
if cors.origins.is_all() {
|
if cors.allowed_origins.is_all() {
|
||||||
cors.origins = AllOrSome::Some(HashSet::new());
|
cors.allowed_origins =
|
||||||
|
AllOrSome::Some(HashSet::with_capacity(8));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let AllOrSome::Some(ref mut origins) = cors.origins {
|
if let Some(origins) = cors.allowed_origins.as_mut() {
|
||||||
origins.insert(origin.to_owned());
|
// any uri is a valid header value
|
||||||
|
let hv = origin.try_into().unwrap();
|
||||||
|
origins.insert(hv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(e) => {
|
Err(err) => {
|
||||||
self.error = Some(e.into());
|
self.error = Some(Either::A(err.into()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,15 +180,26 @@ impl Cors {
|
|||||||
where
|
where
|
||||||
F: (Fn(&RequestHead) -> bool) + 'static,
|
F: (Fn(&RequestHead) -> bool) + 'static,
|
||||||
{
|
{
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.origins_fns.push(OriginFn {
|
cors.allowed_origins_fns.push(OriginFn {
|
||||||
boxed_fn: Box::new(f),
|
boxed_fn: Rc::new(f),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets allowed methods list to all methods.
|
||||||
|
///
|
||||||
|
/// See [`Cors::allowed_methods`] for more info on allowed methods.
|
||||||
|
pub fn allow_any_method(mut self) -> Cors {
|
||||||
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
|
cors.allowed_methods = ALL_METHODS_SET.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set a list of methods which allowed origins can perform.
|
/// Set a list of methods which allowed origins can perform.
|
||||||
///
|
///
|
||||||
/// These will be sent in the `Access-Control-Allow-Methods` response header as specified in
|
/// These will be sent in the `Access-Control-Allow-Methods` response header as specified in
|
||||||
@ -183,16 +214,15 @@ impl Cors {
|
|||||||
M: TryInto<Method>,
|
M: TryInto<Method>,
|
||||||
<M as TryInto<Method>>::Error: Into<HttpError>,
|
<M as TryInto<Method>>::Error: Into<HttpError>,
|
||||||
{
|
{
|
||||||
self.methods = true;
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
|
||||||
for m in methods {
|
for m in methods {
|
||||||
match m.try_into() {
|
match m.try_into() {
|
||||||
Ok(method) => {
|
Ok(method) => {
|
||||||
cors.methods.insert(method);
|
cors.allowed_methods.insert(method);
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(e) => {
|
Err(err) => {
|
||||||
self.error = Some(e.into());
|
self.error = Some(Either::A(err.into()));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -202,34 +232,46 @@ impl Cors {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an allowed header.
|
/// Resets allowed request header list to a state where any origin is accepted.
|
||||||
///
|
///
|
||||||
/// See `Cors::allowed_headers()` for details.
|
/// See [`Cors::allowed_headers`] for more info on allowed request headers.
|
||||||
|
pub fn allow_any_header(mut self) -> Cors {
|
||||||
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
|
cors.allowed_origins = AllOrSome::All;
|
||||||
|
}
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add an allowed request header.
|
||||||
|
///
|
||||||
|
/// See [`Cors::allowed_headers`] for more info on allowed request headers.
|
||||||
pub fn allowed_header<H>(mut self, header: H) -> Cors
|
pub fn allowed_header<H>(mut self, header: H) -> Cors
|
||||||
where
|
where
|
||||||
H: TryInto<HeaderName>,
|
H: TryInto<HeaderName>,
|
||||||
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
|
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
|
||||||
{
|
{
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
match header.try_into() {
|
match header.try_into() {
|
||||||
Ok(method) => {
|
Ok(method) => {
|
||||||
if cors.headers.is_all() {
|
if cors.allowed_headers.is_all() {
|
||||||
cors.headers = AllOrSome::Some(HashSet::new());
|
cors.allowed_headers =
|
||||||
|
AllOrSome::Some(HashSet::with_capacity(8));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let AllOrSome::Some(ref mut headers) = cors.headers {
|
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
|
||||||
headers.insert(method);
|
headers.insert(method);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(e) => self.error = Some(e.into()),
|
Err(err) => self.error = Some(Either::A(err.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set a list of header field names which can be used when this resource is accessed by
|
/// Set a list of request header field names which can be used when this resource is accessed by
|
||||||
/// allowed origins.
|
/// allowed origins.
|
||||||
///
|
///
|
||||||
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
|
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
|
||||||
@ -245,19 +287,21 @@ impl Cors {
|
|||||||
H: TryInto<HeaderName>,
|
H: TryInto<HeaderName>,
|
||||||
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
|
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
|
||||||
{
|
{
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
for h in headers {
|
for h in headers {
|
||||||
match h.try_into() {
|
match h.try_into() {
|
||||||
Ok(method) => {
|
Ok(method) => {
|
||||||
if cors.headers.is_all() {
|
if cors.allowed_headers.is_all() {
|
||||||
cors.headers = AllOrSome::Some(HashSet::new());
|
cors.allowed_headers =
|
||||||
|
AllOrSome::Some(HashSet::with_capacity(8));
|
||||||
}
|
}
|
||||||
if let AllOrSome::Some(ref mut headers) = cors.headers {
|
|
||||||
|
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
|
||||||
headers.insert(method);
|
headers.insert(method);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(err) => {
|
||||||
self.error = Some(e.into());
|
self.error = Some(Either::A(err.into()));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -267,6 +311,17 @@ impl Cors {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets exposed response header list to a state where any header is accepted.
|
||||||
|
///
|
||||||
|
/// See [`Cors::expose_headers`] for more info on exposed response headers.
|
||||||
|
pub fn expose_any_header(mut self) -> Cors {
|
||||||
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
|
cors.allowed_origins = AllOrSome::All;
|
||||||
|
}
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
|
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
|
||||||
/// This corresponds to the `Access-Control-Expose-Headers` response header as specified in
|
/// This corresponds to the `Access-Control-Expose-Headers` response header as specified in
|
||||||
/// the [Fetch Standard CORS protocol].
|
/// the [Fetch Standard CORS protocol].
|
||||||
@ -282,11 +337,19 @@ impl Cors {
|
|||||||
{
|
{
|
||||||
for h in headers {
|
for h in headers {
|
||||||
match h.try_into() {
|
match h.try_into() {
|
||||||
Ok(method) => {
|
Ok(header) => {
|
||||||
self.expose_headers.insert(method);
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
|
if cors.expose_headers.is_all() {
|
||||||
|
cors.expose_headers =
|
||||||
|
AllOrSome::Some(HashSet::with_capacity(8));
|
||||||
|
}
|
||||||
|
if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
|
||||||
|
headers.insert(header);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(err) => {
|
||||||
self.error = Some(e.into());
|
self.error = Some(Either::A(err.into()));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -295,16 +358,16 @@ impl Cors {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set a maximum time for which this CORS request maybe cached.
|
/// Set a maximum time (in seconds) for which this CORS request maybe cached.
|
||||||
/// This value is set as the `Access-Control-Max-Age` header as specified in
|
/// This value is set as the `Access-Control-Max-Age` header as specified in
|
||||||
/// the [Fetch Standard CORS protocol].
|
/// the [Fetch Standard CORS protocol].
|
||||||
///
|
///
|
||||||
/// This defaults to `None` (unset).
|
/// Pass a number (of seconds) or use None to disable sending max age header.
|
||||||
///
|
///
|
||||||
/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
|
/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
|
||||||
pub fn max_age(mut self, max_age: usize) -> Cors {
|
pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.max_age = Some(max_age)
|
cors.max_age = max_age.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
@ -322,7 +385,7 @@ impl Cors {
|
|||||||
///
|
///
|
||||||
/// Defaults to `false`.
|
/// Defaults to `false`.
|
||||||
pub fn send_wildcard(mut self) -> Cors {
|
pub fn send_wildcard(mut self) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.send_wildcard = true
|
cors.send_wildcard = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,12 +403,12 @@ impl Cors {
|
|||||||
///
|
///
|
||||||
/// Defaults to `false`.
|
/// Defaults to `false`.
|
||||||
///
|
///
|
||||||
/// Builder panics during `finish` if credentials are allowed, but the Origin is set to `*`.
|
/// A server initialization error will occur if credentials are allowed, but the Origin is set
|
||||||
/// This is not allowed by the CORS protocol.
|
/// to send wildcards (`*`); this is not allowed by the CORS protocol.
|
||||||
///
|
///
|
||||||
/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
|
/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
|
||||||
pub fn supports_credentials(mut self) -> Cors {
|
pub fn supports_credentials(mut self) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.supports_credentials = true
|
cors.supports_credentials = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,7 +426,7 @@ impl Cors {
|
|||||||
///
|
///
|
||||||
/// By default, `Vary` header support is enabled.
|
/// By default, `Vary` header support is enabled.
|
||||||
pub fn disable_vary_header(mut self) -> Cors {
|
pub fn disable_vary_header(mut self) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.vary_header = false
|
cors.vary_header = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,71 +440,48 @@ impl Cors {
|
|||||||
///
|
///
|
||||||
/// By default *preflight* support is enabled.
|
/// By default *preflight* support is enabled.
|
||||||
pub fn disable_preflight(mut self) -> Cors {
|
pub fn disable_preflight(mut self) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.inner, &self.error) {
|
||||||
cors.preflight = false
|
cors.preflight = false
|
||||||
}
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Construct CORS middleware.
|
impl Default for Cors {
|
||||||
pub fn finish(self) -> CorsFactory {
|
/// A restrictive (security paranoid) set of defaults.
|
||||||
let mut this = if !self.methods {
|
///
|
||||||
self.allowed_methods(vec![
|
/// *No* allowed origins, methods, request headers or exposed headers. Credentials
|
||||||
Method::GET,
|
/// not supported. No max age (will use browser's default).
|
||||||
Method::HEAD,
|
fn default() -> Cors {
|
||||||
Method::POST,
|
let inner = Inner {
|
||||||
Method::OPTIONS,
|
allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
|
||||||
Method::PUT,
|
allowed_origins_fns: tiny_vec![],
|
||||||
Method::PATCH,
|
|
||||||
Method::DELETE,
|
allowed_methods: HashSet::with_capacity(8),
|
||||||
])
|
allowed_methods_baked: None,
|
||||||
} else {
|
|
||||||
self
|
allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
|
||||||
|
allowed_headers_baked: None,
|
||||||
|
|
||||||
|
expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
|
||||||
|
expose_headers_baked: None,
|
||||||
|
|
||||||
|
max_age: None,
|
||||||
|
preflight: true,
|
||||||
|
send_wildcard: false,
|
||||||
|
supports_credentials: false,
|
||||||
|
vary_header: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(e) = this.error.take() {
|
Cors {
|
||||||
panic!("{}", e);
|
inner: Rc::new(inner),
|
||||||
}
|
error: None,
|
||||||
|
|
||||||
let mut cors = this.cors.take().expect("cannot reuse CorsBuilder");
|
|
||||||
|
|
||||||
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
|
|
||||||
panic!("Credentials are allowed, but the Origin is set to \"*\"");
|
|
||||||
}
|
|
||||||
|
|
||||||
if let AllOrSome::Some(ref origins) = cors.origins {
|
|
||||||
let s = origins
|
|
||||||
.iter()
|
|
||||||
.fold(String::new(), |s, v| format!("{}, {}", s, v));
|
|
||||||
cors.origins_str = Some(s[2..].try_into().unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !this.expose_headers.is_empty() {
|
|
||||||
cors.expose_headers = Some(
|
|
||||||
this.expose_headers
|
|
||||||
.iter()
|
|
||||||
.fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..]
|
|
||||||
.to_owned(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
CorsFactory {
|
|
||||||
inner: Rc::new(cors),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Middleware for Cross-Origin Resource Sharing support.
|
impl<S, B> Transform<S> for Cors
|
||||||
///
|
|
||||||
/// This struct contains the settings for CORS requests to be validated and for responses to
|
|
||||||
/// be generated.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct CorsFactory {
|
|
||||||
inner: Rc<Inner>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, B> Transform<S> for CorsFactory
|
|
||||||
where
|
where
|
||||||
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||||
S::Future: 'static,
|
S::Future: 'static,
|
||||||
@ -455,13 +495,74 @@ where
|
|||||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||||
|
|
||||||
fn new_transform(&self, service: S) -> Self::Future {
|
fn new_transform(&self, service: S) -> Self::Future {
|
||||||
ok(CorsMiddleware {
|
if let Some(ref err) = self.error {
|
||||||
service,
|
match err {
|
||||||
inner: Rc::clone(&self.inner),
|
Either::A(err) => error!("{}", err),
|
||||||
})
|
Either::B(err) => error!("{}", err),
|
||||||
|
}
|
||||||
|
|
||||||
|
return future::err(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut inner = Rc::clone(&self.inner);
|
||||||
|
|
||||||
|
if inner.supports_credentials
|
||||||
|
&& inner.send_wildcard
|
||||||
|
&& inner.allowed_origins.is_all()
|
||||||
|
{
|
||||||
|
error!("Illegal combination of CORS options: credentials can not be supported when all \
|
||||||
|
origins are allowed and `send_wildcard` is enabled.");
|
||||||
|
return future::err(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// bake allowed headers value if Some and not empty
|
||||||
|
match inner.allowed_headers.as_ref() {
|
||||||
|
Some(header_set) if !header_set.is_empty() => {
|
||||||
|
let allowed_headers_str = intersperse_header_values(header_set);
|
||||||
|
Rc::make_mut(&mut inner).allowed_headers_baked =
|
||||||
|
Some(allowed_headers_str);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bake allowed methods value if not empty
|
||||||
|
if !inner.allowed_methods.is_empty() {
|
||||||
|
let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
|
||||||
|
Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
// bake exposed headers value if Some and not empty
|
||||||
|
match inner.expose_headers.as_ref() {
|
||||||
|
Some(header_set) if !header_set.is_empty() => {
|
||||||
|
let expose_headers_str = intersperse_header_values(header_set);
|
||||||
|
Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
future::ok(CorsMiddleware { service, inner })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Only call when values are guaranteed to be valid header values and set is not empty.
|
||||||
|
fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
|
||||||
|
where
|
||||||
|
T: AsRef<str>,
|
||||||
|
{
|
||||||
|
val_set
|
||||||
|
.iter()
|
||||||
|
.fold(String::with_capacity(32), |mut acc, val| {
|
||||||
|
acc.push_str(", ");
|
||||||
|
acc.push_str(val.as_ref());
|
||||||
|
acc
|
||||||
|
})
|
||||||
|
// set is not empty so string will always have leading ", " to trim
|
||||||
|
[2..]
|
||||||
|
.try_into()
|
||||||
|
// all method names are valid header values
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use std::convert::{Infallible, TryInto};
|
use std::convert::{Infallible, TryInto};
|
||||||
@ -475,13 +576,20 @@ mod test {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
|
fn illegal_allow_credentials() {
|
||||||
fn cors_validates_illegal_allow_credentials() {
|
// using the permissive defaults (all origins allowed) and adding send_wildcard
|
||||||
let _cors = Cors::new().supports_credentials().send_wildcard().finish();
|
// and supports_credentials should error on construction
|
||||||
|
|
||||||
|
assert!(Cors::permissive()
|
||||||
|
.supports_credentials()
|
||||||
|
.send_wildcard()
|
||||||
|
.new_transform(test::ok_service())
|
||||||
|
.into_inner()
|
||||||
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn default() {
|
async fn restrictive_defaults() {
|
||||||
let mut cors = Cors::default()
|
let mut cors = Cors::default()
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
@ -491,12 +599,12 @@ mod test {
|
|||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn allowed_header_try_from() {
|
async fn allowed_header_try_from() {
|
||||||
let _cors = Cors::new().allowed_header("Content-Type");
|
let _cors = Cors::default().allowed_header("Content-Type");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
@ -511,6 +619,6 @@ mod test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let _cors = Cors::new().allowed_header(ContentType);
|
let _cors = Cors::default().allowed_header(ContentType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,15 +4,16 @@ use derive_more::{Display, Error};
|
|||||||
|
|
||||||
/// Errors that can occur when processing CORS guarded requests.
|
/// Errors that can occur when processing CORS guarded requests.
|
||||||
#[derive(Debug, Clone, Display, Error)]
|
#[derive(Debug, Clone, Display, Error)]
|
||||||
|
#[non_exhaustive]
|
||||||
pub enum CorsError {
|
pub enum CorsError {
|
||||||
|
/// Allowed origin argument must not be wildcard (`*`).
|
||||||
|
#[display(fmt = "`allowed_origin` argument must not be wildcard (`*`).")]
|
||||||
|
WildcardOrigin,
|
||||||
|
|
||||||
/// Request header `Origin` is required but was not provided.
|
/// Request header `Origin` is required but was not provided.
|
||||||
#[display(fmt = "Request header `Origin` is required but was not provided.")]
|
#[display(fmt = "Request header `Origin` is required but was not provided.")]
|
||||||
MissingOrigin,
|
MissingOrigin,
|
||||||
|
|
||||||
/// Request header `Origin` could not be parsed correctly.
|
|
||||||
#[display(fmt = "Request header `Origin` could not be parsed correctly.")]
|
|
||||||
BadOrigin,
|
|
||||||
|
|
||||||
/// Request header `Access-Control-Request-Method` is required but is missing.
|
/// Request header `Access-Control-Request-Method` is required but is missing.
|
||||||
#[display(
|
#[display(
|
||||||
fmt = "Request header `Access-Control-Request-Method` is required but is missing."
|
fmt = "Request header `Access-Control-Request-Method` is required but is missing."
|
||||||
|
@ -1,30 +1,29 @@
|
|||||||
use std::{collections::HashSet, convert::TryInto, fmt};
|
use std::{collections::HashSet, convert::TryFrom, convert::TryInto, fmt, rc::Rc};
|
||||||
|
|
||||||
use actix_web::{
|
use actix_web::{
|
||||||
dev::RequestHead,
|
dev::RequestHead,
|
||||||
error::Result,
|
error::Result,
|
||||||
http::{
|
http::{
|
||||||
self,
|
|
||||||
header::{self, HeaderName, HeaderValue},
|
header::{self, HeaderName, HeaderValue},
|
||||||
Method,
|
Method,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tinyvec::TinyVec;
|
||||||
|
|
||||||
use crate::{AllOrSome, CorsError};
|
use crate::{AllOrSome, CorsError};
|
||||||
|
|
||||||
pub(crate) fn cors<'a>(
|
#[derive(Clone)]
|
||||||
parts: &'a mut Option<Inner>,
|
pub(crate) struct OriginFn {
|
||||||
err: &Option<http::Error>,
|
pub(crate) boxed_fn: Rc<dyn Fn(&RequestHead) -> bool>,
|
||||||
) -> Option<&'a mut Inner> {
|
|
||||||
if err.is_some() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
parts.as_mut()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct OriginFn {
|
impl Default for OriginFn {
|
||||||
pub(crate) boxed_fn: Box<dyn Fn(&RequestHead) -> bool>,
|
/// Dummy default for use in tiny_vec. Do not use.
|
||||||
|
fn default() -> Self {
|
||||||
|
let boxed_fn: Rc<dyn Fn(&_) -> _> = Rc::new(|_req_head| false);
|
||||||
|
Self { boxed_fn }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for OriginFn {
|
impl fmt::Debug for OriginFn {
|
||||||
@ -33,14 +32,28 @@ impl fmt::Debug for OriginFn {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
/// Try to parse header value as HTTP method.
|
||||||
|
fn header_value_try_into_method(hdr: &HeaderValue) -> Option<Method> {
|
||||||
|
hdr.to_str()
|
||||||
|
.ok()
|
||||||
|
.and_then(|meth| Method::try_from(meth).ok())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct Inner {
|
pub(crate) struct Inner {
|
||||||
pub(crate) methods: HashSet<Method>,
|
pub(crate) allowed_origins: AllOrSome<HashSet<HeaderValue>>,
|
||||||
pub(crate) origins: AllOrSome<HashSet<String>>,
|
pub(crate) allowed_origins_fns: TinyVec<[OriginFn; 4]>,
|
||||||
pub(crate) origins_fns: Vec<OriginFn>,
|
|
||||||
pub(crate) origins_str: Option<HeaderValue>,
|
pub(crate) allowed_methods: HashSet<Method>,
|
||||||
pub(crate) headers: AllOrSome<HashSet<HeaderName>>,
|
pub(crate) allowed_methods_baked: Option<HeaderValue>,
|
||||||
pub(crate) expose_headers: Option<String>,
|
|
||||||
|
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderName>>,
|
||||||
|
pub(crate) allowed_headers_baked: Option<HeaderValue>,
|
||||||
|
|
||||||
|
/// `All` will echo back `Access-Control-Request-Header` list.
|
||||||
|
pub(crate) expose_headers: AllOrSome<HashSet<HeaderName>>,
|
||||||
|
pub(crate) expose_headers_baked: Option<HeaderValue>,
|
||||||
|
|
||||||
pub(crate) max_age: Option<usize>,
|
pub(crate) max_age: Option<usize>,
|
||||||
pub(crate) preflight: bool,
|
pub(crate) preflight: bool,
|
||||||
pub(crate) send_wildcard: bool,
|
pub(crate) send_wildcard: bool,
|
||||||
@ -48,90 +61,94 @@ pub(crate) struct Inner {
|
|||||||
pub(crate) vary_header: bool,
|
pub(crate) vary_header: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::new);
|
||||||
|
|
||||||
impl Inner {
|
impl Inner {
|
||||||
pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
|
pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
|
||||||
if let Some(hdr) = req.headers().get(header::ORIGIN) {
|
// return early if all origins are allowed or get ref to allowed origins set
|
||||||
if let Ok(origin) = hdr.to_str() {
|
#[allow(clippy::mutable_key_type)]
|
||||||
return match self.origins {
|
let allowed_origins = match &self.allowed_origins {
|
||||||
AllOrSome::All => Ok(()),
|
AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(()),
|
||||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
AllOrSome::Some(allowed_origins) => allowed_origins,
|
||||||
.get(origin)
|
// only function origin validators are defined
|
||||||
.map(|_| ())
|
_ => &EMPTY_ORIGIN_SET,
|
||||||
.or_else(|| {
|
};
|
||||||
if self.validate_origin_fns(req) {
|
|
||||||
Some(())
|
// get origin header and try to parse as string
|
||||||
} else {
|
match req.headers().get(header::ORIGIN) {
|
||||||
None
|
// origin header exists and is a string
|
||||||
}
|
Some(origin) => {
|
||||||
})
|
if allowed_origins.contains(origin) || self.validate_origin_fns(req) {
|
||||||
.ok_or(CorsError::OriginNotAllowed),
|
Ok(())
|
||||||
};
|
} else {
|
||||||
}
|
Err(CorsError::OriginNotAllowed)
|
||||||
Err(CorsError::BadOrigin)
|
}
|
||||||
} else {
|
|
||||||
match self.origins {
|
|
||||||
AllOrSome::All => Ok(()),
|
|
||||||
_ => Err(CorsError::MissingOrigin),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// origin header is missing
|
||||||
|
// note: with our implementation, the origin header is required for OPTIONS request or
|
||||||
|
// else this would be unreachable
|
||||||
|
None => Err(CorsError::MissingOrigin),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Accepts origin if _ANY_ functions return true.
|
||||||
pub(crate) fn validate_origin_fns(&self, req: &RequestHead) -> bool {
|
pub(crate) fn validate_origin_fns(&self, req: &RequestHead) -> bool {
|
||||||
self.origins_fns
|
self.allowed_origins_fns
|
||||||
.iter()
|
.iter()
|
||||||
.any(|origin_fn| (origin_fn.boxed_fn)(req))
|
.any(|origin_fn| (origin_fn.boxed_fn)(req))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Only called if origin exists and always after it's validated.
|
||||||
pub(crate) fn access_control_allow_origin(
|
pub(crate) fn access_control_allow_origin(
|
||||||
&self,
|
&self,
|
||||||
req: &RequestHead,
|
req: &RequestHead,
|
||||||
) -> Option<HeaderValue> {
|
) -> Option<HeaderValue> {
|
||||||
match self.origins {
|
let origin = req.headers().get(header::ORIGIN);
|
||||||
|
|
||||||
|
match self.allowed_origins {
|
||||||
AllOrSome::All => {
|
AllOrSome::All => {
|
||||||
if self.send_wildcard {
|
if self.send_wildcard {
|
||||||
Some(HeaderValue::from_static("*"))
|
Some(HeaderValue::from_static("*"))
|
||||||
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
|
|
||||||
Some(origin.clone())
|
|
||||||
} else {
|
} else {
|
||||||
None
|
// see note below about why `.cloned()` is correct
|
||||||
|
origin.cloned()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AllOrSome::Some(ref origins) => {
|
|
||||||
if let Some(origin) =
|
AllOrSome::Some(_) => {
|
||||||
req.headers()
|
// since origin (if it exists) is known to be allowed if this method is called
|
||||||
.get(header::ORIGIN)
|
// then cloning the option is all that is required to be used as an echoed back
|
||||||
.filter(|o| match o.to_str() {
|
// header value (or omitted if None)
|
||||||
Ok(os) => origins.contains(os),
|
origin.cloned()
|
||||||
_ => false,
|
|
||||||
})
|
|
||||||
{
|
|
||||||
Some(origin.clone())
|
|
||||||
} else if self.validate_origin_fns(req) {
|
|
||||||
Some(req.headers().get(header::ORIGIN).unwrap().clone())
|
|
||||||
} else {
|
|
||||||
Some(self.origins_str.as_ref().unwrap().clone())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Use in preflight checks and therefore operates on header list in
|
||||||
|
/// `Access-Control-Request-Headers` not the actual header set.
|
||||||
pub(crate) fn validate_allowed_method(
|
pub(crate) fn validate_allowed_method(
|
||||||
&self,
|
&self,
|
||||||
req: &RequestHead,
|
req: &RequestHead,
|
||||||
) -> Result<(), CorsError> {
|
) -> Result<(), CorsError> {
|
||||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
|
// extract access control header and try to parse as method
|
||||||
if let Ok(meth) = hdr.to_str() {
|
let request_method = req
|
||||||
if let Ok(method) = meth.try_into() {
|
.headers()
|
||||||
return self
|
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
|
||||||
.methods
|
.map(header_value_try_into_method);
|
||||||
.get(&method)
|
|
||||||
.map(|_| ())
|
match request_method {
|
||||||
.ok_or(CorsError::MethodNotAllowed);
|
// method valid and allowed
|
||||||
}
|
Some(Some(method)) if self.allowed_methods.contains(&method) => Ok(()),
|
||||||
}
|
|
||||||
Err(CorsError::BadRequestMethod)
|
// method valid but not allowed
|
||||||
} else {
|
Some(Some(_)) => Err(CorsError::MethodNotAllowed),
|
||||||
Err(CorsError::MissingRequestMethod)
|
|
||||||
|
// method invalid
|
||||||
|
Some(_) => Err(CorsError::BadRequestMethod),
|
||||||
|
|
||||||
|
// method missing
|
||||||
|
None => Err(CorsError::MissingRequestMethod),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,34 +156,54 @@ impl Inner {
|
|||||||
&self,
|
&self,
|
||||||
req: &RequestHead,
|
req: &RequestHead,
|
||||||
) -> Result<(), CorsError> {
|
) -> Result<(), CorsError> {
|
||||||
match self.headers {
|
// return early if all headers are allowed or get ref to allowed origins set
|
||||||
AllOrSome::All => Ok(()),
|
#[allow(clippy::mutable_key_type)]
|
||||||
AllOrSome::Some(ref allowed_headers) => {
|
let allowed_headers = match &self.allowed_headers {
|
||||||
if let Some(hdr) =
|
AllOrSome::All => return Ok(()),
|
||||||
req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
AllOrSome::Some(allowed_headers) => allowed_headers,
|
||||||
{
|
};
|
||||||
if let Ok(headers) = hdr.to_str() {
|
|
||||||
#[allow(clippy::mutable_key_type)] // FIXME: revisit here
|
// extract access control header as string
|
||||||
let mut validated_headers = HashSet::new();
|
// header format should be comma separated header names
|
||||||
for hdr in headers.split(',') {
|
let request_headers = req
|
||||||
match hdr.trim().try_into() {
|
.headers()
|
||||||
Ok(hdr) => validated_headers.insert(hdr),
|
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
||||||
Err(_) => return Err(CorsError::BadRequestHeaders),
|
.map(|hdr| hdr.to_str());
|
||||||
};
|
|
||||||
}
|
match request_headers {
|
||||||
// `Access-Control-Request-Headers` must contain 1 or more `field-name`
|
// header list is valid string
|
||||||
if !validated_headers.is_empty() {
|
Some(Ok(headers)) => {
|
||||||
if !validated_headers.is_subset(allowed_headers) {
|
// the set is ephemeral we take care not to mutate the
|
||||||
return Err(CorsError::HeadersNotAllowed);
|
// inserted keys so this lint exception is acceptable
|
||||||
}
|
#[allow(clippy::mutable_key_type)]
|
||||||
return Ok(());
|
let mut request_headers = HashSet::with_capacity(8);
|
||||||
}
|
|
||||||
}
|
// try to convert each header name in the comma-separated list
|
||||||
Err(CorsError::BadRequestHeaders)
|
for hdr in headers.split(',') {
|
||||||
} else {
|
match hdr.trim().try_into() {
|
||||||
Ok(())
|
Ok(hdr) => request_headers.insert(hdr),
|
||||||
|
Err(_) => return Err(CorsError::BadRequestHeaders),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// header list must contain 1 or more header name
|
||||||
|
if request_headers.is_empty() {
|
||||||
|
return Err(CorsError::BadRequestHeaders);
|
||||||
|
}
|
||||||
|
|
||||||
|
// request header list must be a subset of allowed headers
|
||||||
|
if !request_headers.is_subset(allowed_headers) {
|
||||||
|
return Err(CorsError::HeadersNotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// header list is not a string
|
||||||
|
Some(Err(_)) => Err(CorsError::BadRequestHeaders),
|
||||||
|
|
||||||
|
// header list missing
|
||||||
|
None => Ok(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,40 +214,43 @@ mod test {
|
|||||||
|
|
||||||
use actix_web::{
|
use actix_web::{
|
||||||
dev::Transform,
|
dev::Transform,
|
||||||
http::{header, Method, StatusCode},
|
http::{header, HeaderValue, Method, StatusCode},
|
||||||
test::{self, TestRequest},
|
test::{self, TestRequest},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::Cors;
|
use crate::Cors;
|
||||||
|
|
||||||
|
fn val_as_str(val: &HeaderValue) -> &str {
|
||||||
|
val.to_str().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
#[should_panic(expected = "OriginNotAllowed")]
|
|
||||||
async fn test_validate_not_allowed_origin() {
|
async fn test_validate_not_allowed_origin() {
|
||||||
let cors = Cors::new()
|
let cors = Cors::default()
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let req = TestRequest::with_header("Origin", "https://www.unknown.com")
|
let req = TestRequest::get()
|
||||||
.method(Method::GET)
|
.header(header::ORIGIN, "https://www.unknown.com")
|
||||||
|
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT")
|
||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
|
|
||||||
cors.inner.validate_origin(req.head()).unwrap();
|
assert!(cors.inner.validate_origin(req.head()).is_err());
|
||||||
cors.inner.validate_allowed_method(req.head()).unwrap();
|
assert!(cors.inner.validate_allowed_method(req.head()).is_err());
|
||||||
cors.inner.validate_allowed_headers(req.head()).unwrap();
|
assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_preflight() {
|
async fn test_preflight() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
|
.allow_any_origin()
|
||||||
.send_wildcard()
|
.send_wildcard()
|
||||||
.max_age(3600)
|
.max_age(3600)
|
||||||
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
|
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
|
||||||
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
|
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
|
||||||
.allowed_header(header::CONTENT_TYPE)
|
.allowed_header(header::CONTENT_TYPE)
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -244,24 +284,22 @@ mod test {
|
|||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"*"[..],
|
Some(&b"*"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"3600"[..],
|
Some(&b"3600"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_MAX_AGE)
|
.get(header::ACCESS_CONTROL_MAX_AGE)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let hdr = resp
|
let hdr = resp
|
||||||
.headers()
|
.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
|
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
|
||||||
.unwrap()
|
.map(val_as_str)
|
||||||
.to_str()
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(hdr.contains("authorization"));
|
assert!(hdr.contains("authorization"));
|
||||||
assert!(hdr.contains("accept"));
|
assert!(hdr.contains("accept"));
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
//! Cross-Origin Resource Sharing (CORS) controls for Actix Web.
|
//! Cross-Origin Resource Sharing (CORS) controls for Actix Web.
|
||||||
//!
|
//!
|
||||||
//! This middleware can be applied to both applications and resources. Once built,
|
//! This middleware can be applied to both applications and resources. Once built, a
|
||||||
//! [`CorsFactory`] can be used as a parameter for actix-web `App::wrap()`,
|
//! [`Cors`] builder can be used as an argument for Actix Web's `App::wrap()`,
|
||||||
//! `Scope::wrap()`, or `Resource::wrap()` methods.
|
//! `Scope::wrap()`, or `Resource::wrap()` methods.
|
||||||
//!
|
//!
|
||||||
//! This CORS middleware automatically handles `OPTIONS` preflight requests.
|
//! This CORS middleware automatically handles `OPTIONS` preflight requests.
|
||||||
//!
|
//!
|
||||||
//! # Example
|
//! # Example
|
||||||
//!
|
|
||||||
//! In this example a custom CORS middleware is registered for the "/index.html" endpoint.
|
|
||||||
//!
|
|
||||||
//! ```rust,no_run
|
//! ```rust,no_run
|
||||||
//! use actix_cors::Cors;
|
//! use actix_cors::Cors;
|
||||||
//! use actix_web::{get, http, web, App, HttpRequest, HttpResponse, HttpServer};
|
//! use actix_web::{get, http, web, App, HttpRequest, HttpResponse, HttpServer};
|
||||||
@ -22,7 +19,7 @@
|
|||||||
//! #[actix_web::main]
|
//! #[actix_web::main]
|
||||||
//! async fn main() -> std::io::Result<()> {
|
//! async fn main() -> std::io::Result<()> {
|
||||||
//! HttpServer::new(|| {
|
//! HttpServer::new(|| {
|
||||||
//! let cors = Cors::new()
|
//! let cors = Cors::default()
|
||||||
//! .allowed_origin("https://www.rust-lang.org/")
|
//! .allowed_origin("https://www.rust-lang.org/")
|
||||||
//! .allowed_origin_fn(|req| {
|
//! .allowed_origin_fn(|req| {
|
||||||
//! req.headers
|
//! req.headers
|
||||||
@ -34,8 +31,7 @@
|
|||||||
//! .allowed_methods(vec!["GET", "POST"])
|
//! .allowed_methods(vec!["GET", "POST"])
|
||||||
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
|
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
|
||||||
//! .allowed_header(http::header::CONTENT_TYPE)
|
//! .allowed_header(http::header::CONTENT_TYPE)
|
||||||
//! .max_age(3600)
|
//! .max_age(3600);
|
||||||
//! .finish();
|
|
||||||
//!
|
//!
|
||||||
//! App::new()
|
//! App::new()
|
||||||
//! .wrap(cors)
|
//! .wrap(cors)
|
||||||
@ -61,8 +57,8 @@ mod error;
|
|||||||
mod inner;
|
mod inner;
|
||||||
mod middleware;
|
mod middleware;
|
||||||
|
|
||||||
pub use all_or_some::AllOrSome;
|
use all_or_some::AllOrSome;
|
||||||
pub use builder::{Cors, CorsFactory};
|
pub use builder::Cors;
|
||||||
pub use error::CorsError;
|
pub use error::CorsError;
|
||||||
use inner::{cors, Inner, OriginFn};
|
use inner::{Inner, OriginFn};
|
||||||
pub use middleware::CorsMiddleware;
|
pub use middleware::CorsMiddleware;
|
||||||
|
@ -14,6 +14,7 @@ use actix_web::{
|
|||||||
HttpResponse,
|
HttpResponse,
|
||||||
};
|
};
|
||||||
use futures_util::future::{ok, Either, FutureExt as _, LocalBoxFuture, Ready};
|
use futures_util::future::{ok, Either, FutureExt as _, LocalBoxFuture, Ready};
|
||||||
|
use log::debug;
|
||||||
|
|
||||||
use crate::Inner;
|
use crate::Inner;
|
||||||
|
|
||||||
@ -28,6 +29,93 @@ pub struct CorsMiddleware<S> {
|
|||||||
pub(crate) inner: Rc<Inner>,
|
pub(crate) inner: Rc<Inner>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S> CorsMiddleware<S> {
|
||||||
|
fn handle_preflight<B>(inner: &Inner, req: ServiceRequest) -> ServiceResponse<B> {
|
||||||
|
if let Err(err) = inner
|
||||||
|
.validate_origin(req.head())
|
||||||
|
.and_then(|_| inner.validate_allowed_method(req.head()))
|
||||||
|
.and_then(|_| inner.validate_allowed_headers(req.head()))
|
||||||
|
{
|
||||||
|
return req.error_response(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut res = HttpResponse::Ok();
|
||||||
|
|
||||||
|
if let Some(origin) = inner.access_control_allow_origin(req.head()) {
|
||||||
|
res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref allowed_methods) = inner.allowed_methods_baked {
|
||||||
|
res.header(
|
||||||
|
header::ACCESS_CONTROL_ALLOW_METHODS,
|
||||||
|
allowed_methods.clone(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref headers) = inner.allowed_headers_baked {
|
||||||
|
res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone());
|
||||||
|
} else if let Some(headers) =
|
||||||
|
req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
||||||
|
{
|
||||||
|
// all headers allowed, return
|
||||||
|
res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if inner.supports_credentials {
|
||||||
|
res.header(
|
||||||
|
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
||||||
|
HeaderValue::from_static("true"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_age) = inner.max_age {
|
||||||
|
res.header(header::ACCESS_CONTROL_MAX_AGE, max_age.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = res.finish();
|
||||||
|
let res = res.into_body();
|
||||||
|
req.into_response(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn augment_response<B>(
|
||||||
|
inner: &Inner,
|
||||||
|
mut res: ServiceResponse<B>,
|
||||||
|
) -> ServiceResponse<B> {
|
||||||
|
if let Some(origin) = inner.access_control_allow_origin(res.request().head()) {
|
||||||
|
res.headers_mut()
|
||||||
|
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(ref expose) = inner.expose_headers_baked {
|
||||||
|
res.headers_mut()
|
||||||
|
.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if inner.supports_credentials {
|
||||||
|
res.headers_mut().insert(
|
||||||
|
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
||||||
|
HeaderValue::from_static("true"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if inner.vary_header {
|
||||||
|
let value = match res.headers_mut().get(header::VARY) {
|
||||||
|
Some(hdr) => {
|
||||||
|
let mut val: Vec<u8> = Vec::with_capacity(hdr.len() + 8);
|
||||||
|
val.extend(hdr.as_bytes());
|
||||||
|
val.extend(b", Origin");
|
||||||
|
val.try_into().unwrap()
|
||||||
|
}
|
||||||
|
None => HeaderValue::from_static("Origin"),
|
||||||
|
};
|
||||||
|
|
||||||
|
res.headers_mut().insert(header::VARY, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type CorsMiddlewareServiceFuture<B> = Either<
|
type CorsMiddlewareServiceFuture<B> = Either<
|
||||||
Ready<Result<ServiceResponse<B>, Error>>,
|
Ready<Result<ServiceResponse<B>, Error>>,
|
||||||
LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>,
|
LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>,
|
||||||
@ -49,123 +137,84 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: ServiceRequest) -> Self::Future {
|
fn call(&mut self, req: ServiceRequest) -> Self::Future {
|
||||||
if self.inner.preflight && Method::OPTIONS == *req.method() {
|
if self.inner.preflight && req.method() == Method::OPTIONS {
|
||||||
if let Err(e) = self
|
let inner = Rc::clone(&self.inner);
|
||||||
.inner
|
let res = Self::handle_preflight(&inner, req);
|
||||||
.validate_origin(req.head())
|
Either::Left(ok(res))
|
||||||
.and_then(|_| self.inner.validate_allowed_method(req.head()))
|
|
||||||
.and_then(|_| self.inner.validate_allowed_headers(req.head()))
|
|
||||||
{
|
|
||||||
return Either::Left(ok(req.error_response(e)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// allowed headers
|
|
||||||
let headers = if let Some(headers) = self.inner.headers.as_ref() {
|
|
||||||
Some(
|
|
||||||
headers
|
|
||||||
.iter()
|
|
||||||
.fold(String::new(), |s, v| s + "," + v.as_str())
|
|
||||||
.as_str()[1..]
|
|
||||||
.try_into()
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
} else if let Some(hdr) =
|
|
||||||
req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
|
||||||
{
|
|
||||||
Some(hdr.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let res = HttpResponse::Ok()
|
|
||||||
.if_some(self.inner.max_age.as_ref(), |max_age, resp| {
|
|
||||||
let _ = resp.header(
|
|
||||||
header::ACCESS_CONTROL_MAX_AGE,
|
|
||||||
format!("{}", max_age).as_str(),
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.if_some(headers, |headers, resp| {
|
|
||||||
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
|
|
||||||
})
|
|
||||||
.if_some(
|
|
||||||
self.inner.access_control_allow_origin(req.head()),
|
|
||||||
|origin, resp| {
|
|
||||||
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.if_true(self.inner.supports_credentials, |resp| {
|
|
||||||
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
|
|
||||||
})
|
|
||||||
.header(
|
|
||||||
header::ACCESS_CONTROL_ALLOW_METHODS,
|
|
||||||
&self
|
|
||||||
.inner
|
|
||||||
.methods
|
|
||||||
.iter()
|
|
||||||
.fold(String::new(), |s, v| s + "," + v.as_str())
|
|
||||||
.as_str()[1..],
|
|
||||||
)
|
|
||||||
.finish()
|
|
||||||
.into_body();
|
|
||||||
|
|
||||||
Either::Left(ok(req.into_response(res)))
|
|
||||||
} else {
|
} else {
|
||||||
if req.headers().contains_key(header::ORIGIN) {
|
let origin = req.headers().get(header::ORIGIN).cloned();
|
||||||
|
|
||||||
|
if origin.is_some() {
|
||||||
// Only check requests with a origin header.
|
// Only check requests with a origin header.
|
||||||
if let Err(e) = self.inner.validate_origin(req.head()) {
|
if let Err(err) = self.inner.validate_origin(req.head()) {
|
||||||
return Either::Left(ok(req.error_response(e)));
|
debug!("origin validation failed; inner service is not called");
|
||||||
|
return Either::Left(ok(req.error_response(err)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let inner = Rc::clone(&self.inner);
|
let inner = Rc::clone(&self.inner);
|
||||||
let has_origin = req.headers().contains_key(header::ORIGIN);
|
|
||||||
let fut = self.service.call(req);
|
let fut = self.service.call(req);
|
||||||
|
|
||||||
Either::Right(
|
let res = async move {
|
||||||
async move {
|
let res = fut.await;
|
||||||
let res = fut.await;
|
|
||||||
|
|
||||||
if has_origin {
|
if origin.is_some() {
|
||||||
let mut res = res?;
|
let res = res?;
|
||||||
if let Some(origin) =
|
Ok(Self::augment_response(&inner, res))
|
||||||
inner.access_control_allow_origin(res.request().head())
|
} else {
|
||||||
{
|
res
|
||||||
res.headers_mut()
|
|
||||||
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref expose) = inner.expose_headers {
|
|
||||||
res.headers_mut().insert(
|
|
||||||
header::ACCESS_CONTROL_EXPOSE_HEADERS,
|
|
||||||
expose.as_str().try_into().unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if inner.supports_credentials {
|
|
||||||
res.headers_mut().insert(
|
|
||||||
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
|
||||||
HeaderValue::from_static("true"),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if inner.vary_header {
|
|
||||||
let value =
|
|
||||||
if let Some(hdr) = res.headers_mut().get(header::VARY) {
|
|
||||||
let mut val: Vec<u8> =
|
|
||||||
Vec::with_capacity(hdr.as_bytes().len() + 8);
|
|
||||||
val.extend(hdr.as_bytes());
|
|
||||||
val.extend(b", Origin");
|
|
||||||
val.try_into().unwrap()
|
|
||||||
} else {
|
|
||||||
HeaderValue::from_static("Origin")
|
|
||||||
};
|
|
||||||
res.headers_mut().insert(header::VARY, value);
|
|
||||||
}
|
|
||||||
Ok(res)
|
|
||||||
} else {
|
|
||||||
res
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
.boxed_local(),
|
}
|
||||||
)
|
.boxed_local();
|
||||||
|
|
||||||
|
Either::Right(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use actix_web::{
|
||||||
|
dev::Transform,
|
||||||
|
test::{self, TestRequest},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::Cors;
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn test_options_no_origin() {
|
||||||
|
// Tests case where allowed_origins is All but there are validate functions to run incase.
|
||||||
|
// In this case, origins are only allowed when the DNT header is sent.
|
||||||
|
|
||||||
|
let mut cors = Cors::default()
|
||||||
|
.allow_any_origin()
|
||||||
|
.allowed_origin_fn(|req_head| req_head.headers().contains_key(header::DNT))
|
||||||
|
.new_transform(test::ok_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let req = TestRequest::get()
|
||||||
|
.header(header::ORIGIN, "http://example.com")
|
||||||
|
.to_srv_request();
|
||||||
|
let res = cors.call(req).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
None,
|
||||||
|
res.headers()
|
||||||
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
);
|
||||||
|
|
||||||
|
let req = TestRequest::get()
|
||||||
|
.header(header::ORIGIN, "http://example.com")
|
||||||
|
.header(header::DNT, "1")
|
||||||
|
.to_srv_request();
|
||||||
|
let res = cors.call(req).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
Some(&b"http://example.com"[..]),
|
||||||
|
res.headers()
|
||||||
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -10,12 +10,15 @@ use regex::bytes::Regex;
|
|||||||
|
|
||||||
use actix_cors::Cors;
|
use actix_cors::Cors;
|
||||||
|
|
||||||
|
fn val_as_str(val: &HeaderValue) -> &str {
|
||||||
|
val.to_str().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
#[should_panic]
|
#[should_panic]
|
||||||
async fn test_wildcard_origin() {
|
async fn test_wildcard_origin() {
|
||||||
Cors::new()
|
Cors::default()
|
||||||
.allowed_origin("*")
|
.allowed_origin("*")
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -23,7 +26,7 @@ async fn test_wildcard_origin() {
|
|||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_not_allowed_origin_fn() {
|
async fn test_not_allowed_origin_fn() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.allowed_origin_fn(|req| {
|
.allowed_origin_fn(|req| {
|
||||||
req.headers
|
req.headers
|
||||||
@ -32,7 +35,6 @@ async fn test_not_allowed_origin_fn() {
|
|||||||
.filter(|b| b.ends_with(b".unknown.com"))
|
.filter(|b| b.ends_with(b".unknown.com"))
|
||||||
.is_some()
|
.is_some()
|
||||||
})
|
})
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -68,7 +70,7 @@ async fn test_not_allowed_origin_fn() {
|
|||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_allowed_origin_fn() {
|
async fn test_allowed_origin_fn() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.allowed_origin_fn(|req| {
|
.allowed_origin_fn(|req| {
|
||||||
req.headers
|
req.headers
|
||||||
@ -77,7 +79,6 @@ async fn test_allowed_origin_fn() {
|
|||||||
.filter(|b| b.ends_with(b".unknown.com"))
|
.filter(|b| b.ends_with(b".unknown.com"))
|
||||||
.is_some()
|
.is_some()
|
||||||
})
|
})
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -92,8 +93,7 @@ async fn test_allowed_origin_fn() {
|
|||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(val_as_str)
|
||||||
.to_str()
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ async fn test_allowed_origin_fn() {
|
|||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_allowed_origin_fn_with_environment() {
|
async fn test_allowed_origin_fn_with_environment() {
|
||||||
let regex = Regex::new("https:.+\\.unknown\\.com").unwrap();
|
let regex = Regex::new("https:.+\\.unknown\\.com").unwrap();
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.allowed_origin_fn(move |req| {
|
.allowed_origin_fn(move |req| {
|
||||||
req.headers
|
req.headers
|
||||||
@ -123,7 +123,6 @@ async fn test_allowed_origin_fn_with_environment() {
|
|||||||
.filter(|b| regex.is_match(b))
|
.filter(|b| regex.is_match(b))
|
||||||
.is_some()
|
.is_some()
|
||||||
})
|
})
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -138,8 +137,7 @@ async fn test_allowed_origin_fn_with_environment() {
|
|||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(val_as_str)
|
||||||
.to_str()
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -159,11 +157,10 @@ async fn test_allowed_origin_fn_with_environment() {
|
|||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_multiple_origins_preflight() {
|
async fn test_multiple_origins_preflight() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://example.com")
|
.allowed_origin("https://example.com")
|
||||||
.allowed_origin("https://example.org")
|
.allowed_origin("https://example.org")
|
||||||
.allowed_methods(vec![Method::GET])
|
.allowed_methods(vec![Method::GET])
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -175,11 +172,10 @@ async fn test_multiple_origins_preflight() {
|
|||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"https://example.com"[..],
|
Some(&b"https://example.com"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let req = TestRequest::with_header("Origin", "https://example.org")
|
let req = TestRequest::with_header("Origin", "https://example.org")
|
||||||
@ -189,21 +185,19 @@ async fn test_multiple_origins_preflight() {
|
|||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"https://example.org"[..],
|
Some(&b"https://example.org"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_multiple_origins() {
|
async fn test_multiple_origins() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://example.com")
|
.allowed_origin("https://example.com")
|
||||||
.allowed_origin("https://example.org")
|
.allowed_origin("https://example.org")
|
||||||
.allowed_methods(vec![Method::GET])
|
.allowed_methods(vec![Method::GET])
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -214,11 +208,10 @@ async fn test_multiple_origins() {
|
|||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"https://example.com"[..],
|
Some(&b"https://example.com"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let req = TestRequest::with_header("Origin", "https://example.org")
|
let req = TestRequest::with_header("Origin", "https://example.org")
|
||||||
@ -227,18 +220,18 @@ async fn test_multiple_origins() {
|
|||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"https://example.org"[..],
|
Some(&b"https://example.org"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_response() {
|
async fn test_response() {
|
||||||
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
|
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
|
.allow_any_origin()
|
||||||
.send_wildcard()
|
.send_wildcard()
|
||||||
.disable_preflight()
|
.disable_preflight()
|
||||||
.max_age(3600)
|
.max_age(3600)
|
||||||
@ -246,7 +239,6 @@ async fn test_response() {
|
|||||||
.allowed_headers(exposed_headers.clone())
|
.allowed_headers(exposed_headers.clone())
|
||||||
.expose_headers(exposed_headers.clone())
|
.expose_headers(exposed_headers.clone())
|
||||||
.allowed_header(header::CONTENT_TYPE)
|
.allowed_header(header::CONTENT_TYPE)
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -254,18 +246,16 @@ async fn test_response() {
|
|||||||
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||||
.method(Method::OPTIONS)
|
.method(Method::OPTIONS)
|
||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
|
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"*"[..],
|
Some(&b"*"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"Origin"[..],
|
Some(&b"Origin"[..]),
|
||||||
resp.headers().get(header::VARY).unwrap().as_bytes()
|
resp.headers().get(header::VARY).map(HeaderValue::as_bytes)
|
||||||
);
|
);
|
||||||
|
|
||||||
#[allow(clippy::needless_collect)]
|
#[allow(clippy::needless_collect)]
|
||||||
@ -273,20 +263,21 @@ async fn test_response() {
|
|||||||
let headers = resp
|
let headers = resp
|
||||||
.headers()
|
.headers()
|
||||||
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
|
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
|
||||||
.unwrap()
|
.map(val_as_str)
|
||||||
.to_str()
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.split(',')
|
.split(',')
|
||||||
.map(|s| s.trim())
|
.map(|s| s.trim())
|
||||||
.collect::<Vec<&str>>();
|
.collect::<Vec<&str>>();
|
||||||
|
|
||||||
|
// TODO: use HashSet subset check
|
||||||
for h in exposed_headers {
|
for h in exposed_headers {
|
||||||
assert!(headers.contains(&h.as_str()));
|
assert!(headers.contains(&h.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
|
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
|
.allow_any_origin()
|
||||||
.send_wildcard()
|
.send_wildcard()
|
||||||
.disable_preflight()
|
.disable_preflight()
|
||||||
.max_age(3600)
|
.max_age(3600)
|
||||||
@ -294,28 +285,28 @@ async fn test_response() {
|
|||||||
.allowed_headers(exposed_headers.clone())
|
.allowed_headers(exposed_headers.clone())
|
||||||
.expose_headers(exposed_headers.clone())
|
.expose_headers(exposed_headers.clone())
|
||||||
.allowed_header(header::CONTENT_TYPE)
|
.allowed_header(header::CONTENT_TYPE)
|
||||||
.finish()
|
|
||||||
.new_transform(fn_service(|req: ServiceRequest| {
|
.new_transform(fn_service(|req: ServiceRequest| {
|
||||||
ok(req.into_response(
|
ok(req.into_response({
|
||||||
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
|
HttpResponse::Ok().header(header::VARY, "Accept").finish()
|
||||||
))
|
}))
|
||||||
}))
|
}))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||||
.method(Method::OPTIONS)
|
.method(Method::OPTIONS)
|
||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"Accept, Origin"[..],
|
Some(&b"Accept, Origin"[..]),
|
||||||
resp.headers().get(header::VARY).unwrap().as_bytes()
|
resp.headers().get(header::VARY).map(HeaderValue::as_bytes)
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.disable_vary_header()
|
.disable_vary_header()
|
||||||
|
.allowed_methods(vec!["POST"])
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.allowed_origin("https://www.google.com")
|
.allowed_origin("https://www.google.com")
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -325,22 +316,17 @@ async fn test_response() {
|
|||||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
|
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
|
||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
|
|
||||||
let origins_str = resp
|
let origins_str = resp
|
||||||
.headers()
|
.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(val_as_str);
|
||||||
.to_str()
|
assert_eq!(Some("https://www.example.com"), origins_str);
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!("https://www.example.com", origins_str);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_validate_origin() {
|
async fn test_validate_origin() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::default()
|
||||||
.allowed_origin("https://www.example.com")
|
.allowed_origin("https://www.example.com")
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -355,9 +341,8 @@ async fn test_validate_origin() {
|
|||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_no_origin_response() {
|
async fn test_no_origin_response() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::permissive()
|
||||||
.disable_preflight()
|
.disable_preflight()
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -374,18 +359,16 @@ async fn test_no_origin_response() {
|
|||||||
.to_srv_request();
|
.to_srv_request();
|
||||||
let resp = test::call_service(&mut cors, req).await;
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&b"https://www.example.com"[..],
|
Some(&b"https://www.example.com"[..]),
|
||||||
resp.headers()
|
resp.headers()
|
||||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
.unwrap()
|
.map(HeaderValue::as_bytes)
|
||||||
.as_bytes()
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn validate_origin_allows_all_origins() {
|
async fn validate_origin_allows_all_origins() {
|
||||||
let mut cors = Cors::new()
|
let mut cors = Cors::permissive()
|
||||||
.finish()
|
|
||||||
.new_transform(test::ok_service())
|
.new_transform(test::ok_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
Loading…
Reference in New Issue
Block a user