mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-27 17:22:57 +01:00
actix-cors: Create allowed_origin_fn
builder method (#93)
Co-authored-by: ArmorDarks <git@lavrins.com>
This commit is contained in:
parent
bb8120a8c0
commit
99fe08f332
@ -1,6 +1,7 @@
|
|||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
## Unreleased - 2020-xx-xx
|
## Unreleased - 2020-xx-xx
|
||||||
|
* Implement `allowed_origin_fn` builder method.
|
||||||
|
|
||||||
|
|
||||||
## 0.3.0 - 2020-09-11
|
## 0.3.0 - 2020-09-11
|
||||||
|
@ -25,6 +25,13 @@
|
|||||||
//! .wrap(
|
//! .wrap(
|
||||||
//! Cors::new() // <- Construct CORS middleware builder
|
//! Cors::new() // <- Construct CORS middleware builder
|
||||||
//! .allowed_origin("https://www.rust-lang.org/")
|
//! .allowed_origin("https://www.rust-lang.org/")
|
||||||
|
//! .allowed_origin_fn(|req| {
|
||||||
|
//! req.headers
|
||||||
|
//! .get(http::header::ORIGIN)
|
||||||
|
//! .map(http::HeaderValue::as_bytes)
|
||||||
|
//! .filter(|b| b.ends_with(b".rust-lang.org"))
|
||||||
|
//! .is_some()
|
||||||
|
//! })
|
||||||
//! .allowed_methods(vec!["GET", "POST"])
|
//! .allowed_methods(vec!["GET", "POST"])
|
||||||
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
|
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
|
||||||
//! .allowed_header(http::header::CONTENT_TYPE)
|
//! .allowed_header(http::header::CONTENT_TYPE)
|
||||||
@ -46,6 +53,7 @@
|
|||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
use std::fmt;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
@ -186,6 +194,7 @@ impl Cors {
|
|||||||
cors: Some(Inner {
|
cors: Some(Inner {
|
||||||
origins: AllOrSome::All,
|
origins: AllOrSome::All,
|
||||||
origins_str: None,
|
origins_str: None,
|
||||||
|
origins_fns: Vec::new(),
|
||||||
methods: HashSet::new(),
|
methods: HashSet::new(),
|
||||||
headers: AllOrSome::All,
|
headers: AllOrSome::All,
|
||||||
expose_hdrs: None,
|
expose_hdrs: None,
|
||||||
@ -206,6 +215,7 @@ impl Cors {
|
|||||||
let inner = Inner {
|
let inner = Inner {
|
||||||
origins: AllOrSome::default(),
|
origins: AllOrSome::default(),
|
||||||
origins_str: None,
|
origins_str: None,
|
||||||
|
origins_fns: Vec::new(),
|
||||||
methods: HashSet::from_iter(
|
methods: HashSet::from_iter(
|
||||||
vec![
|
vec![
|
||||||
Method::GET,
|
Method::GET,
|
||||||
@ -248,6 +258,10 @@ impl Cors {
|
|||||||
/// If `send_wildcard` is not set, the client's `Origin` request header
|
/// If `send_wildcard` is not set, the client's `Origin` request header
|
||||||
/// will be echoed back in the `Access-Control-Allow-Origin` response header.
|
/// will be echoed back in the `Access-Control-Allow-Origin` response header.
|
||||||
///
|
///
|
||||||
|
/// If the origin of the request doesn't match any allowed origins and at least
|
||||||
|
/// one `allowed_origin_fn` function is set, these functions will be used
|
||||||
|
/// to determinate allowed origins.
|
||||||
|
///
|
||||||
/// Builder panics if supplied origin is not valid uri.
|
/// Builder panics if supplied origin is not valid uri.
|
||||||
pub fn allowed_origin(mut self, origin: &str) -> Cors {
|
pub fn allowed_origin(mut self, origin: &str) -> Cors {
|
||||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||||
@ -268,6 +282,21 @@ impl Cors {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Determinate allowed origins by processing requests which didn't match any origins
|
||||||
|
/// specified in the `allowed_origin`.
|
||||||
|
///
|
||||||
|
/// The function will receive a `RequestHead` of each request, which can be used
|
||||||
|
/// to determine whether it should be allowed or not.
|
||||||
|
///
|
||||||
|
/// If the function returns `true`, the client's `Origin` request header will be echoed
|
||||||
|
/// back into the `Access-Control-Allow-Origin` response header.
|
||||||
|
pub fn allowed_origin_fn(mut self, f: fn(req: &RequestHead) -> bool) -> Cors {
|
||||||
|
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||||
|
cors.origins_fns.push(OriginFn { f });
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set a list of methods which allowed origins can perform.
|
/// Set a list of methods which allowed origins can perform.
|
||||||
///
|
///
|
||||||
/// This is the `list of methods` in the
|
/// This is the `list of methods` in the
|
||||||
@ -567,10 +596,21 @@ pub struct CorsMiddleware<S> {
|
|||||||
inner: Rc<Inner>,
|
inner: Rc<Inner>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct OriginFn {
|
||||||
|
f: fn(req: &RequestHead) -> bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for OriginFn {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "origin_fn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Inner {
|
struct Inner {
|
||||||
methods: HashSet<Method>,
|
methods: HashSet<Method>,
|
||||||
origins: AllOrSome<HashSet<String>>,
|
origins: AllOrSome<HashSet<String>>,
|
||||||
|
origins_fns: Vec<OriginFn>,
|
||||||
origins_str: Option<HeaderValue>,
|
origins_str: Option<HeaderValue>,
|
||||||
headers: AllOrSome<HashSet<HeaderName>>,
|
headers: AllOrSome<HashSet<HeaderName>>,
|
||||||
expose_hdrs: Option<String>,
|
expose_hdrs: Option<String>,
|
||||||
@ -590,6 +630,13 @@ impl Inner {
|
|||||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||||
.get(origin)
|
.get(origin)
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
|
.or_else(|| {
|
||||||
|
if self.validate_origin_fns(req) {
|
||||||
|
Some(())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
.ok_or(CorsError::OriginNotAllowed),
|
.ok_or(CorsError::OriginNotAllowed),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -602,6 +649,10 @@ impl Inner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_origin_fns(&self, req: &RequestHead) -> bool {
|
||||||
|
self.origins_fns.iter().any(|origin_fn| (origin_fn.f)(req))
|
||||||
|
}
|
||||||
|
|
||||||
fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
|
fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
|
||||||
match self.origins {
|
match self.origins {
|
||||||
AllOrSome::All => {
|
AllOrSome::All => {
|
||||||
@ -623,6 +674,8 @@ impl Inner {
|
|||||||
})
|
})
|
||||||
{
|
{
|
||||||
Some(origin.clone())
|
Some(origin.clone())
|
||||||
|
} else if self.validate_origin_fns(req) {
|
||||||
|
Some(req.headers().get(&header::ORIGIN).unwrap().clone())
|
||||||
} else {
|
} else {
|
||||||
Some(self.origins_str.as_ref().unwrap().clone())
|
Some(self.origins_str.as_ref().unwrap().clone())
|
||||||
}
|
}
|
||||||
@ -1205,4 +1258,98 @@ mod tests {
|
|||||||
.as_bytes()
|
.as_bytes()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn test_allowed_origin_fn() {
|
||||||
|
let mut cors = Cors::new()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.allowed_origin_fn(|req| {
|
||||||
|
req.headers
|
||||||
|
.get(header::ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
.filter(|b| b.ends_with(b".unknown.com"))
|
||||||
|
.is_some()
|
||||||
|
})
|
||||||
|
.finish()
|
||||||
|
.new_transform(test::ok_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||||
|
.method(Method::GET)
|
||||||
|
.to_srv_request();
|
||||||
|
|
||||||
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
"https://www.example.com",
|
||||||
|
resp.headers()
|
||||||
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let req = TestRequest::with_header("Origin", "https://www.unknown.com")
|
||||||
|
.method(Method::GET)
|
||||||
|
.to_srv_request();
|
||||||
|
|
||||||
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
Some(&b"https://www.unknown.com"[..]),
|
||||||
|
resp.headers()
|
||||||
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn test_not_allowed_origin_fn() {
|
||||||
|
let mut cors = Cors::new()
|
||||||
|
.allowed_origin("https://www.example.com")
|
||||||
|
.allowed_origin_fn(|req| {
|
||||||
|
req.headers
|
||||||
|
.get(header::ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
.filter(|b| b.ends_with(b".unknown.com"))
|
||||||
|
.is_some()
|
||||||
|
})
|
||||||
|
.finish()
|
||||||
|
.new_transform(test::ok_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||||
|
.method(Method::GET)
|
||||||
|
.to_srv_request();
|
||||||
|
|
||||||
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
Some(&b"https://www.example.com"[..]),
|
||||||
|
resp.headers()
|
||||||
|
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
.map(HeaderValue::as_bytes)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let req = TestRequest::with_header("Origin", "https://www.known.com")
|
||||||
|
.method(Method::GET)
|
||||||
|
.to_srv_request();
|
||||||
|
|
||||||
|
let resp = test::call_service(&mut cors, req).await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
None,
|
||||||
|
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user