mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-23 23:51:06 +01:00
actix-cors: Create allowed_origin_fn
builder method (#93)
Co-authored-by: ArmorDarks <git@lavrins.com>
This commit is contained in:
parent
bb8120a8c0
commit
99fe08f332
@ -1,6 +1,7 @@
|
||||
# Changes
|
||||
|
||||
## Unreleased - 2020-xx-xx
|
||||
* Implement `allowed_origin_fn` builder method.
|
||||
|
||||
|
||||
## 0.3.0 - 2020-09-11
|
||||
|
@ -25,6 +25,13 @@
|
||||
//! .wrap(
|
||||
//! Cors::new() // <- Construct CORS middleware builder
|
||||
//! .allowed_origin("https://www.rust-lang.org/")
|
||||
//! .allowed_origin_fn(|req| {
|
||||
//! req.headers
|
||||
//! .get(http::header::ORIGIN)
|
||||
//! .map(http::HeaderValue::as_bytes)
|
||||
//! .filter(|b| b.ends_with(b".rust-lang.org"))
|
||||
//! .is_some()
|
||||
//! })
|
||||
//! .allowed_methods(vec!["GET", "POST"])
|
||||
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
|
||||
//! .allowed_header(http::header::CONTENT_TYPE)
|
||||
@ -46,6 +53,7 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt;
|
||||
use std::iter::FromIterator;
|
||||
use std::rc::Rc;
|
||||
use std::task::{Context, Poll};
|
||||
@ -186,6 +194,7 @@ impl Cors {
|
||||
cors: Some(Inner {
|
||||
origins: AllOrSome::All,
|
||||
origins_str: None,
|
||||
origins_fns: Vec::new(),
|
||||
methods: HashSet::new(),
|
||||
headers: AllOrSome::All,
|
||||
expose_hdrs: None,
|
||||
@ -206,6 +215,7 @@ impl Cors {
|
||||
let inner = Inner {
|
||||
origins: AllOrSome::default(),
|
||||
origins_str: None,
|
||||
origins_fns: Vec::new(),
|
||||
methods: HashSet::from_iter(
|
||||
vec![
|
||||
Method::GET,
|
||||
@ -248,6 +258,10 @@ impl Cors {
|
||||
/// If `send_wildcard` is not set, the client's `Origin` request header
|
||||
/// will be echoed back in the `Access-Control-Allow-Origin` response header.
|
||||
///
|
||||
/// If the origin of the request doesn't match any allowed origins and at least
|
||||
/// one `allowed_origin_fn` function is set, these functions will be used
|
||||
/// to determinate allowed origins.
|
||||
///
|
||||
/// Builder panics if supplied origin is not valid uri.
|
||||
pub fn allowed_origin(mut self, origin: &str) -> Cors {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
@ -268,6 +282,21 @@ impl Cors {
|
||||
self
|
||||
}
|
||||
|
||||
/// Determinate allowed origins by processing requests which didn't match any origins
|
||||
/// specified in the `allowed_origin`.
|
||||
///
|
||||
/// The function will receive a `RequestHead` of each request, which can be used
|
||||
/// to determine whether it should be allowed or not.
|
||||
///
|
||||
/// If the function returns `true`, the client's `Origin` request header will be echoed
|
||||
/// back into the `Access-Control-Allow-Origin` response header.
|
||||
pub fn allowed_origin_fn(mut self, f: fn(req: &RequestHead) -> bool) -> Cors {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
cors.origins_fns.push(OriginFn { f });
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a list of methods which allowed origins can perform.
|
||||
///
|
||||
/// This is the `list of methods` in the
|
||||
@ -567,10 +596,21 @@ pub struct CorsMiddleware<S> {
|
||||
inner: Rc<Inner>,
|
||||
}
|
||||
|
||||
struct OriginFn {
|
||||
f: fn(req: &RequestHead) -> bool,
|
||||
}
|
||||
|
||||
impl fmt::Debug for OriginFn {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "origin_fn")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Inner {
|
||||
methods: HashSet<Method>,
|
||||
origins: AllOrSome<HashSet<String>>,
|
||||
origins_fns: Vec<OriginFn>,
|
||||
origins_str: Option<HeaderValue>,
|
||||
headers: AllOrSome<HashSet<HeaderName>>,
|
||||
expose_hdrs: Option<String>,
|
||||
@ -590,6 +630,13 @@ impl Inner {
|
||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||
.get(origin)
|
||||
.map(|_| ())
|
||||
.or_else(|| {
|
||||
if self.validate_origin_fns(req) {
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(CorsError::OriginNotAllowed),
|
||||
};
|
||||
}
|
||||
@ -602,6 +649,10 @@ impl Inner {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_origin_fns(&self, req: &RequestHead) -> bool {
|
||||
self.origins_fns.iter().any(|origin_fn| (origin_fn.f)(req))
|
||||
}
|
||||
|
||||
fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
|
||||
match self.origins {
|
||||
AllOrSome::All => {
|
||||
@ -623,6 +674,8 @@ impl Inner {
|
||||
})
|
||||
{
|
||||
Some(origin.clone())
|
||||
} else if self.validate_origin_fns(req) {
|
||||
Some(req.headers().get(&header::ORIGIN).unwrap().clone())
|
||||
} else {
|
||||
Some(self.origins_str.as_ref().unwrap().clone())
|
||||
}
|
||||
@ -1205,4 +1258,98 @@ mod tests {
|
||||
.as_bytes()
|
||||
);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_allowed_origin_fn() {
|
||||
let mut cors = Cors::new()
|
||||
.allowed_origin("https://www.example.com")
|
||||
.allowed_origin_fn(|req| {
|
||||
req.headers
|
||||
.get(header::ORIGIN)
|
||||
.map(HeaderValue::as_bytes)
|
||||
.filter(|b| b.ends_with(b".unknown.com"))
|
||||
.is_some()
|
||||
})
|
||||
.finish()
|
||||
.new_transform(test::ok_service())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::GET)
|
||||
.to_srv_request();
|
||||
|
||||
let resp = test::call_service(&mut cors, req).await;
|
||||
|
||||
assert_eq!(
|
||||
"https://www.example.com",
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let req = TestRequest::with_header("Origin", "https://www.unknown.com")
|
||||
.method(Method::GET)
|
||||
.to_srv_request();
|
||||
|
||||
let resp = test::call_service(&mut cors, req).await;
|
||||
|
||||
assert_eq!(
|
||||
Some(&b"https://www.unknown.com"[..]),
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.map(HeaderValue::as_bytes)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_not_allowed_origin_fn() {
|
||||
let mut cors = Cors::new()
|
||||
.allowed_origin("https://www.example.com")
|
||||
.allowed_origin_fn(|req| {
|
||||
req.headers
|
||||
.get(header::ORIGIN)
|
||||
.map(HeaderValue::as_bytes)
|
||||
.filter(|b| b.ends_with(b".unknown.com"))
|
||||
.is_some()
|
||||
})
|
||||
.finish()
|
||||
.new_transform(test::ok_service())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::GET)
|
||||
.to_srv_request();
|
||||
|
||||
let resp = test::call_service(&mut cors, req).await;
|
||||
|
||||
assert_eq!(
|
||||
Some(&b"https://www.example.com"[..]),
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.map(HeaderValue::as_bytes)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let req = TestRequest::with_header("Origin", "https://www.known.com")
|
||||
.method(Method::GET)
|
||||
.to_srv_request();
|
||||
|
||||
let resp = test::call_service(&mut cors, req).await;
|
||||
|
||||
assert_eq!(
|
||||
None,
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user