1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-22 14:55:56 +01:00

actix-cors: Create allowed_origin_fn builder method (#93)

Co-authored-by: ArmorDarks <git@lavrins.com>
This commit is contained in:
Roman Lakhtadyr 2020-09-25 02:36:53 +03:00 committed by GitHub
parent bb8120a8c0
commit 99fe08f332
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 148 additions and 0 deletions

View File

@ -1,6 +1,7 @@
# Changes
## Unreleased - 2020-xx-xx
* Implement `allowed_origin_fn` builder method.
## 0.3.0 - 2020-09-11

View File

@ -25,6 +25,13 @@
//! .wrap(
//! Cors::new() // <- Construct CORS middleware builder
//! .allowed_origin("https://www.rust-lang.org/")
//! .allowed_origin_fn(|req| {
//! req.headers
//! .get(http::header::ORIGIN)
//! .map(http::HeaderValue::as_bytes)
//! .filter(|b| b.ends_with(b".rust-lang.org"))
//! .is_some()
//! })
//! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE)
@ -46,6 +53,7 @@
use std::collections::HashSet;
use std::convert::TryFrom;
use std::fmt;
use std::iter::FromIterator;
use std::rc::Rc;
use std::task::{Context, Poll};
@ -186,6 +194,7 @@ impl Cors {
cors: Some(Inner {
origins: AllOrSome::All,
origins_str: None,
origins_fns: Vec::new(),
methods: HashSet::new(),
headers: AllOrSome::All,
expose_hdrs: None,
@ -206,6 +215,7 @@ impl Cors {
let inner = Inner {
origins: AllOrSome::default(),
origins_str: None,
origins_fns: Vec::new(),
methods: HashSet::from_iter(
vec![
Method::GET,
@ -248,6 +258,10 @@ impl Cors {
/// If `send_wildcard` is not set, the client's `Origin` request header
/// will be echoed back in the `Access-Control-Allow-Origin` response header.
///
/// If the origin of the request doesn't match any allowed origins and at least
/// one `allowed_origin_fn` function is set, these functions will be used
/// to determinate allowed origins.
///
/// Builder panics if supplied origin is not valid uri.
pub fn allowed_origin(mut self, origin: &str) -> Cors {
if let Some(cors) = cors(&mut self.cors, &self.error) {
@ -268,6 +282,21 @@ impl Cors {
self
}
/// Determinate allowed origins by processing requests which didn't match any origins
/// specified in the `allowed_origin`.
///
/// The function will receive a `RequestHead` of each request, which can be used
/// to determine whether it should be allowed or not.
///
/// If the function returns `true`, the client's `Origin` request header will be echoed
/// back into the `Access-Control-Allow-Origin` response header.
pub fn allowed_origin_fn(mut self, f: fn(req: &RequestHead) -> bool) -> Cors {
if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.origins_fns.push(OriginFn { f });
}
self
}
/// Set a list of methods which allowed origins can perform.
///
/// This is the `list of methods` in the
@ -567,10 +596,21 @@ pub struct CorsMiddleware<S> {
inner: Rc<Inner>,
}
struct OriginFn {
f: fn(req: &RequestHead) -> bool,
}
impl fmt::Debug for OriginFn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "origin_fn")
}
}
#[derive(Debug)]
struct Inner {
methods: HashSet<Method>,
origins: AllOrSome<HashSet<String>>,
origins_fns: Vec<OriginFn>,
origins_str: Option<HeaderValue>,
headers: AllOrSome<HashSet<HeaderName>>,
expose_hdrs: Option<String>,
@ -590,6 +630,13 @@ impl Inner {
AllOrSome::Some(ref allowed_origins) => allowed_origins
.get(origin)
.map(|_| ())
.or_else(|| {
if self.validate_origin_fns(req) {
Some(())
} else {
None
}
})
.ok_or(CorsError::OriginNotAllowed),
};
}
@ -602,6 +649,10 @@ impl Inner {
}
}
fn validate_origin_fns(&self, req: &RequestHead) -> bool {
self.origins_fns.iter().any(|origin_fn| (origin_fn.f)(req))
}
fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
match self.origins {
AllOrSome::All => {
@ -623,6 +674,8 @@ impl Inner {
})
{
Some(origin.clone())
} else if self.validate_origin_fns(req) {
Some(req.headers().get(&header::ORIGIN).unwrap().clone())
} else {
Some(self.origins_str.as_ref().unwrap().clone())
}
@ -1205,4 +1258,98 @@ mod tests {
.as_bytes()
);
}
#[actix_rt::test]
async fn test_allowed_origin_fn() {
let mut cors = Cors::new()
.allowed_origin("https://www.example.com")
.allowed_origin_fn(|req| {
req.headers
.get(header::ORIGIN)
.map(HeaderValue::as_bytes)
.filter(|b| b.ends_with(b".unknown.com"))
.is_some()
})
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
{
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
"https://www.example.com",
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.to_str()
.unwrap()
);
}
{
let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
Some(&b"https://www.unknown.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(HeaderValue::as_bytes)
);
}
}
#[actix_rt::test]
async fn test_not_allowed_origin_fn() {
let mut cors = Cors::new()
.allowed_origin("https://www.example.com")
.allowed_origin_fn(|req| {
req.headers
.get(header::ORIGIN)
.map(HeaderValue::as_bytes)
.filter(|b| b.ends_with(b".unknown.com"))
.is_some()
})
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
{
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
Some(&b"https://www.example.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(HeaderValue::as_bytes)
);
}
{
let req = TestRequest::with_header("Origin", "https://www.known.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
None,
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
);
}
}
}