use actix_cors::Cors; use actix_utils::future::ok; use actix_web::{ dev::{fn_service, ServiceRequest, Transform}, http::{ header::{self, HeaderValue}, Method, StatusCode, }, test::{self, TestRequest}, HttpResponse, }; use regex::bytes::Regex; fn val_as_str(val: &HeaderValue) -> &str { val.to_str().unwrap() } #[actix_web::test] #[should_panic] async fn test_wildcard_origin() { Cors::default() .allowed_origin("*") .new_transform(test::ok_service()) .await .unwrap(); } #[actix_web::test] async fn test_not_allowed_origin_fn() { let cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(|origin, req| { assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes) .filter(|b| b.ends_with(b".unknown.com")) .is_some() }) .new_transform(test::ok_service()) .await .unwrap(); { let req = TestRequest::get() .insert_header(("Origin", "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://www.example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } { let req = TestRequest::get() .insert_header(("Origin", "https://www.known.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( None, resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN) ); } } #[actix_web::test] async fn test_allowed_origin_fn() { let cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(|origin, req| { assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes) .filter(|b| b.ends_with(b".unknown.com")) .is_some() }) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header(("Origin", "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( "https://www.example.com", resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(val_as_str) .unwrap() ); let req = TestRequest::get() .insert_header(("Origin", "https://www.unknown.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://www.unknown.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } #[actix_web::test] async fn test_allowed_origin_fn_with_environment() { let regex = Regex::new("https:.+\\.unknown\\.com").unwrap(); let cors = Cors::default() .allowed_origin("https://www.example.com") .allowed_origin_fn(move |origin, req| { assert_eq!(&origin, req.headers.get(header::ORIGIN).unwrap()); req.headers .get(header::ORIGIN) .map(HeaderValue::as_bytes) .filter(|b| regex.is_match(b)) .is_some() }) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header(("Origin", "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( "https://www.example.com", resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(val_as_str) .unwrap() ); let req = TestRequest::get() .insert_header(("Origin", "https://www.unknown.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://www.unknown.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } #[actix_web::test] async fn test_multiple_origins_preflight() { let cors = Cors::default() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .insert_header(("Origin", "https://example.com")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); let req = TestRequest::default() .insert_header(("Origin", "https://example.org")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://example.org"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } #[actix_web::test] async fn test_multiple_origins() { let cors = Cors::default() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header(("Origin", "https://example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); let req = TestRequest::get() .insert_header(("Origin", "https://example.org")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://example.org"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } #[actix_web::test] async fn test_response() { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let cors = Cors::default() .allow_any_origin() .send_wildcard() .disable_preflight() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .insert_header(("Origin", "https://www.example.com")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"*"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers().get(header::VARY).map(HeaderValue::as_bytes), Some(&b"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"[..]), ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers().get(header::VARY).map(HeaderValue::as_bytes), Some(&b"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"[..]), ); #[allow(clippy::needless_collect)] { let headers = resp .headers() .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) .map(val_as_str) .unwrap() .split(',') .map(|s| s.trim()) .collect::>(); // TODO: use HashSet subset check for h in exposed_headers { assert!(headers.contains(&h.as_str())); } } let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let cors = Cors::default() .allow_any_origin() .send_wildcard() .disable_preflight() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) .new_transform(fn_service(|req: ServiceRequest| { ok(req.into_response({ HttpResponse::Ok() .insert_header((header::VARY, "Accept")) .finish() })) })) .await .unwrap(); let req = TestRequest::default() .insert_header(("Origin", "https://www.example.com")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) .map(HeaderValue::as_bytes) .unwrap(), b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers().get(header::VARY).map(HeaderValue::as_bytes).unwrap(), b"Accept, Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", ); let cors = Cors::default() .disable_vary_header() .allowed_methods(vec!["POST"]) .allowed_origin("https://www.example.com") .allowed_origin("https://www.google.com") .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .insert_header(("Origin", "https://www.example.com")) .method(Method::OPTIONS) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) .to_srv_request(); let resp = test::call_service(&cors, req).await; let origins_str = resp .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(val_as_str); assert_eq!(Some("https://www.example.com"), origins_str); } #[actix_web::test] async fn test_validate_origin() { let cors = Cors::default() .allowed_origin("https://www.example.com") .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header(("Origin", "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_web::test] async fn test_blocks_mismatched_origin_by_default() { let cors = Cors::default() .allowed_origin("https://www.example.com") .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::get() .insert_header(("Origin", "https://www.example.test")) .to_srv_request(); let res = test::call_service(&cors, req).await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); assert!(res .headers() .get(header::ACCESS_CONTROL_ALLOW_METHODS) .is_none()); } #[actix_web::test] async fn test_mismatched_origin_block_turned_off() { let cors = Cors::default() .allow_any_method() .allowed_origin("https://www.example.com") .block_on_origin_mismatch(false) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .method(Method::OPTIONS) .insert_header(("Origin", "https://wrong.com")) .insert_header(("Access-Control-Request-Method", "POST")) .to_srv_request(); let res = test::call_service(&cors, req).await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); let req = TestRequest::get() .insert_header(("Origin", "https://wrong.com")) .to_srv_request(); let res = test::call_service(&cors, req).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), None); } #[actix_web::test] async fn test_no_origin_response() { let cors = Cors::permissive() .disable_preflight() .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default().method(Method::GET).to_srv_request(); let resp = test::call_service(&cors, req).await; assert!(resp .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .is_none()); let req = TestRequest::default() .insert_header(("Origin", "https://www.example.com")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( Some(&b"https://www.example.com"[..]), resp.headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .map(HeaderValue::as_bytes) ); } #[actix_web::test] async fn validate_origin_allows_all_origins() { let cors = Cors::permissive() .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .insert_header(("Origin", "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_web::test] async fn vary_header_on_all_handled_responses() { let cors = Cors::permissive() .new_transform(test::ok_service()) .await .unwrap(); // preflight request let req = TestRequest::default() .method(Method::OPTIONS) .insert_header((header::ORIGIN, "https://www.example.com")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "GET")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); assert!(resp .headers() .contains_key(header::ACCESS_CONTROL_ALLOW_METHODS)); #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", ); // follow-up regular request let req = TestRequest::default() .method(Method::PUT) .insert_header((header::ORIGIN, "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", ); let cors = Cors::default() .allow_any_method() .new_transform(test::ok_service()) .await .unwrap(); // regular request bad origin let req = TestRequest::default() .method(Method::PUT) .insert_header((header::ORIGIN, "https://www.example.com")) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", ); // regular request no origin let req = TestRequest::default().method(Method::PUT).to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); #[cfg(not(feature = "draft-private-network-access"))] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); #[cfg(feature = "draft-private-network-access")] assert_eq!( resp.headers() .get(header::VARY) .expect("response should have Vary header") .to_str() .unwrap(), "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", ); } #[actix_web::test] async fn test_allow_any_origin_any_method_any_header() { let cors = Cors::default() .allow_any_origin() .allow_any_method() .allow_any_header() .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default() .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")) .insert_header((header::ORIGIN, "https://www.example.com")) .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_web::test] async fn expose_all_request_header_values() { let cors = Cors::permissive() .new_transform(fn_service(|req: ServiceRequest| async move { let res = req.into_response( HttpResponse::Ok() .insert_header((header::CONTENT_DISPOSITION, "test disposition")) .finish(), ); Ok(res) })) .await .unwrap(); let req = TestRequest::default() .insert_header((header::ORIGIN, "https://www.example.com")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "content-type")) .to_srv_request(); let res = test::call_service(&cors, req).await; let cd_hdr = res .headers() .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) .unwrap() .to_str() .unwrap(); assert!(cd_hdr.contains("content-disposition")); assert!(cd_hdr.contains("access-control-allow-origin")); } #[cfg(feature = "draft-private-network-access")] #[actix_web::test] async fn private_network_access() { let cors = Cors::permissive() .allowed_origin("https://public.site") .allow_private_network_access() .new_transform(fn_service(|req: ServiceRequest| async move { let res = req.into_response( HttpResponse::Ok() .insert_header((header::CONTENT_DISPOSITION, "test disposition")) .finish(), ); Ok(res) })) .await .unwrap(); let req = TestRequest::default() .insert_header((header::ORIGIN, "https://public.site")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) .insert_header((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")) .to_srv_request(); let res = test::call_service(&cors, req).await; assert!(res.headers().contains_key("access-control-allow-origin")); let req = TestRequest::default() .insert_header((header::ORIGIN, "https://public.site")) .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST")) .insert_header((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")) .insert_header(("Access-Control-Request-Private-Network", "true")) .to_srv_request(); let res = test::call_service(&cors, req).await; assert!(res.headers().contains_key("access-control-allow-origin")); assert!(res .headers() .contains_key("access-control-allow-private-network")); }