// Mirror of https://github.com/actix/actix-extras.git, synced 2025-01-22 23:05:56 +01:00.
// Source file: Rust, 451 lines, 13 KiB.

use std::collections::HashSet;

use actix_cors::Cors;
use actix_service::fn_service;
use actix_utils::future::ok;
use actix_web::{
    dev::{ServiceRequest, Transform},
    http::{
        header::{self, HeaderValue},
        Method, StatusCode,
    },
    test::{self, TestRequest},
    HttpResponse,
};
use regex::bytes::Regex;
fn val_as_str(val: &HeaderValue) -> &str {
val.to_str().unwrap()
}
#[actix_rt::test]
#[should_panic]
async fn test_wildcard_origin() {
2020-10-19 05:51:31 +01:00
Cors::default()
.allowed_origin("*")
.new_transform(test::ok_service())
.await
.unwrap();
}
#[actix_rt::test]
async fn test_not_allowed_origin_fn() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://www.example.com")
.allowed_origin_fn(|origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
req.headers
.get(header::ORIGIN)
.map(HeaderValue::as_bytes)
.filter(|b| b.ends_with(b".unknown.com"))
.is_some()
})
.new_transform(test::ok_service())
.await
.unwrap();
{
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
Some(&b"https://www.example.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(HeaderValue::as_bytes)
);
}
{
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.known.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
None,
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
);
}
}
#[actix_rt::test]
async fn test_allowed_origin_fn() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://www.example.com")
.allowed_origin_fn(|origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
req.headers
.get(header::ORIGIN)
.map(HeaderValue::as_bytes)
.filter(|b| b.ends_with(b".unknown.com"))
.is_some()
})
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
"https://www.example.com",
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(val_as_str)
.unwrap()
);
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.unknown.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
Some(&b"https://www.unknown.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(HeaderValue::as_bytes)
);
}
#[actix_rt::test]
async fn test_allowed_origin_fn_with_environment() {
let regex = Regex::new("https:.+\\.unknown\\.com").unwrap();
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://www.example.com")
.allowed_origin_fn(move |origin, req| {
assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap());
req.headers
.get(header::ORIGIN)
.map(HeaderValue::as_bytes)
.filter(|b| regex.is_match(b))
.is_some()
})
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
"https://www.example.com",
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(val_as_str)
.unwrap()
);
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.unknown.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
Some(&b"https://www.unknown.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(HeaderValue::as_bytes)
);
}
#[actix_rt::test]
async fn test_multiple_origins_preflight() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://example.com")
.allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET])
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://example.com"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET"))
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"https://example.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://example.org"))
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET"))
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"https://example.org"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
}
#[actix_rt::test]
async fn test_multiple_origins() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://example.com")
.allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET])
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"https://example.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://example.org"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"https://example.org"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
}
#[actix_rt::test]
async fn test_response() {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
2020-10-19 05:51:31 +01:00
.allow_any_origin()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"*"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"Origin"[..]),
resp.headers().get(header::VARY).map(HeaderValue::as_bytes)
);
#[allow(clippy::needless_collect)]
{
let headers = resp
.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
2020-10-19 05:51:31 +01:00
.map(val_as_str)
.unwrap()
.split(',')
.map(|s| s.trim())
.collect::<Vec<&str>>();
2020-10-19 05:51:31 +01:00
// TODO: use HashSet subset check
for h in exposed_headers {
assert!(headers.contains(&h.as_str()));
}
}
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
2020-10-19 05:51:31 +01:00
.allow_any_origin()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.new_transform(fn_service(|req: ServiceRequest| {
2020-10-19 05:51:31 +01:00
ok(req.into_response({
2021-03-21 23:50:26 +01:00
HttpResponse::Ok()
.insert_header((header::VARY, "Accept"))
.finish()
2020-10-19 05:51:31 +01:00
}))
}))
.await
.unwrap();
2020-10-19 05:51:31 +01:00
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"Accept, Origin"[..]),
resp.headers().get(header::VARY).map(HeaderValue::as_bytes)
);
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.disable_vary_header()
2020-10-19 05:51:31 +01:00
.allowed_methods(vec!["POST"])
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS)
2021-03-21 23:50:26 +01:00
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
let origins_str = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(val_as_str);
assert_eq!(Some("https://www.example.com"), origins_str);
}
#[actix_rt::test]
async fn test_validate_origin() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
.allowed_origin("https://www.example.com")
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::get()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[actix_rt::test]
async fn test_no_origin_response() {
2021-03-21 23:50:26 +01:00
let cors = Cors::permissive()
.disable_preflight()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert!(resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none());
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(
2020-10-19 05:51:31 +01:00
Some(&b"https://www.example.com"[..]),
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
2020-10-19 05:51:31 +01:00
.map(HeaderValue::as_bytes)
);
}
#[actix_rt::test]
async fn validate_origin_allows_all_origins() {
2021-03-21 23:50:26 +01:00
let cors = Cors::permissive()
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header(("Origin", "https://www.example.com"))
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
}
2020-11-05 19:38:27 +01:00
#[actix_rt::test]
async fn test_allow_any_origin_any_method_any_header() {
2021-03-21 23:50:26 +01:00
let cors = Cors::default()
2020-11-05 19:38:27 +01:00
.allow_any_origin()
.allow_any_method()
.allow_any_header()
.new_transform(test::ok_service())
.await
.unwrap();
2021-03-21 23:50:26 +01:00
let req = TestRequest::default()
.insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
.insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type"))
.insert_header((header::ORIGIN, "https://www.example.com"))
2020-11-05 19:38:27 +01:00
.method(Method::OPTIONS)
.to_srv_request();
2021-03-21 23:50:26 +01:00
let resp = test::call_service(&cors, req).await;
2020-11-05 19:38:27 +01:00
assert_eq!(resp.status(), StatusCode::OK);
}
/// Under `Cors::permissive()`, the response's `Access-Control-Expose-Headers`
/// value must be present and mention the custom `X-XSRF-TOKEN` request header
/// (checked as a substring match on "xsrf-token").
#[actix_web::test]
async fn expose_all_request_header_values() {
    let cors = Cors::permissive()
        .new_transform(test::ok_service())
        .await
        .unwrap();

    let req = TestRequest::default()
        .insert_header((header::ORIGIN, "https://www.example.com"))
        .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
        .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type"))
        .insert_header(("X-XSRF-TOKEN", "xsrf-token"))
        .to_srv_request();

    let resp = test::call_service(&cors, req).await;
    let headers = resp.headers();

    assert!(headers.contains_key(header::ACCESS_CONTROL_EXPOSE_HEADERS));

    let exposed = headers
        .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
        .unwrap()
        .to_str()
        .unwrap();
    assert!(exposed.contains("xsrf-token"));
}