1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-18 13:51:50 +01:00

migrate actix-cors

This commit is contained in:
Nikolay Kim 2019-11-21 10:54:07 +06:00
parent 3646725cf6
commit 6ac4ac66b9
3 changed files with 383 additions and 345 deletions

View File

@ -31,7 +31,7 @@ members = [
".", ".",
"awc", "awc",
"actix-http", "actix-http",
#"actix-cors", "actix-cors",
#"actix-files", #"actix-files",
#"actix-framed", #"actix-framed",
#"actix-session", #"actix-session",

View File

@ -17,7 +17,7 @@ name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = "1.0.9" actix-web = "2.0.0-alpha.1"
actix-service = "0.4.0" actix-service = "1.0.0-alpha.1"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.3.1" futures = "0.3.1"

View File

@ -23,7 +23,8 @@
//! .allowed_methods(vec!["GET", "POST"]) //! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE) //! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600)) //! .max_age(3600)
//! .finish())
//! .service( //! .service(
//! web::resource("/index.html") //! web::resource("/index.html")
//! .route(web::get().to(index)) //! .route(web::get().to(index))
@ -41,16 +42,16 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{IntoTransform, Service, Transform}; use actix_service::{Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result}; use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use derive_more::Display; use derive_more::Display;
use futures::future::{ok, Either, Future, FutureResult}; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::Poll;
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Display)] #[derive(Debug, Display)]
@ -456,25 +457,9 @@ impl Cors {
} }
self self
} }
}
fn cors<'a>( /// Construct cors middleware
parts: &'a mut Option<Inner>, pub fn finish(self) -> CorsFactory {
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl<S, B> IntoTransform<CorsFactory, S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn into_transform(self) -> CorsFactory {
let mut slf = if !self.methods { let mut slf = if !self.methods {
self.allowed_methods(vec![ self.allowed_methods(vec![
Method::GET, Method::GET,
@ -521,6 +506,16 @@ where
} }
} }
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
/// `Middleware` for Cross-origin resource sharing support /// `Middleware` for Cross-origin resource sharing support
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
@ -540,7 +535,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware { ok(CorsMiddleware {
@ -682,12 +677,12 @@ where
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<
FutureResult<Self::Response, Error>, Ready<Result<Self::Response, Error>>,
Either<S::Future, Box<dyn Future<Item = Self::Response, Error = Error>>>, LocalBoxFuture<'static, Result<Self::Response, Error>>,
>; >;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready() self.service.poll_ready(cx)
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@ -698,7 +693,7 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head()))
{ {
return Either::A(ok(req.error_response(e))); return Either::Left(ok(req.error_response(e)));
} }
// allowed headers // allowed headers
@ -751,22 +746,32 @@ where
.finish() .finish()
.into_body(); .into_body();
Either::A(ok(req.into_response(res))) Either::Left(ok(req.into_response(res)))
} else if req.headers().contains_key(&header::ORIGIN) { } else {
if req.headers().contains_key(&header::ORIGIN) {
// Only check requests with a origin header. // Only check requests with a origin header.
if let Err(e) = self.inner.validate_origin(req.head()) { if let Err(e) = self.inner.validate_origin(req.head()) {
return Either::A(ok(req.error_response(e))); return Either::Left(ok(req.error_response(e)));
}
} }
let inner = self.inner.clone(); let inner = self.inner.clone();
let has_origin = req.headers().contains_key(&header::ORIGIN);
let fut = self.service.call(req);
Either::B(Either::B(Box::new(self.service.call(req).and_then( Either::Right(
move |mut res| { async move {
let res = fut.await;
if has_origin {
let mut res = res?;
if let Some(origin) = if let Some(origin) =
inner.access_control_allow_origin(res.request().head()) inner.access_control_allow_origin(res.request().head())
{ {
res.headers_mut() res.headers_mut().insert(
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
}; };
if let Some(ref expose) = inner.expose_hdrs { if let Some(ref expose) = inner.expose_hdrs {
@ -782,8 +787,9 @@ where
); );
} }
if inner.vary_header { if inner.vary_header {
let value = let value = if let Some(hdr) =
if let Some(hdr) = res.headers_mut().get(&header::VARY) { res.headers_mut().get(&header::VARY)
{
let mut val: Vec<u8> = let mut val: Vec<u8> =
Vec::with_capacity(hdr.as_bytes().len() + 8); Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
@ -795,80 +801,73 @@ where
res.headers_mut().insert(header::VARY, value); res.headers_mut().insert(header::VARY, value);
} }
Ok(res) Ok(res)
},
))))
} else { } else {
Either::B(Either::A(self.service.call(req))) res
}
}
.boxed_local(),
)
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{IntoService, Transform}; use actix_service::{service_fn2, Transform};
use actix_web::test::{self, block_on, TestRequest}; use actix_web::test::{self, block_on, TestRequest};
use super::*; use super::*;
impl Cors {
fn finish<F, S, B>(self, srv: F) -> CorsMiddleware<S>
where
F: IntoService<S>,
S: Service<
Request = ServiceRequest,
Response = ServiceResponse<B>,
Error = Error,
> + 'static,
S::Future: 'static,
B: 'static,
{
block_on(
IntoTransform::<CorsFactory, S>::into_transform(self)
.new_transform(srv.into_service()),
)
.unwrap()
}
}
#[test] #[test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() { fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new() let _cors = Cors::new().supports_credentials().send_wildcard().finish();
.supports_credentials()
.send_wildcard()
.finish(test::ok_service());
} }
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let mut cors = Cors::new().finish(test::ok_service()); block_on(async {
let mut cors = Cors::new()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn default() { fn default() {
let mut cors = block_on(async {
block_on(Cors::default().new_transform(test::ok_service())).unwrap(); let mut cors = Cors::default()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_preflight() { fn test_preflight() {
block_on(async {
let mut cors = Cors::new() let mut cors = Cors::new()
.send_wildcard() .send_wildcard()
.max_age(3600) .max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
@ -877,7 +876,7 @@ mod tests {
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
@ -897,7 +896,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@ -943,8 +942,9 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
// #[test] // #[test]
@ -960,9 +960,13 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
block_on(async {
let cors = Cors::new() let cors = Cors::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
@ -970,28 +974,40 @@ mod tests {
cors.inner.validate_origin(req.head()).unwrap(); cors.inner.validate_origin(req.head()).unwrap();
cors.inner.validate_allowed_method(req.head()).unwrap(); cors.inner.validate_allowed_method(req.head()).unwrap();
cors.inner.validate_allowed_headers(req.head()).unwrap(); cors.inner.validate_allowed_headers(req.head()).unwrap();
})
} }
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
block_on(async {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_no_origin_response() { fn test_no_origin_response() {
let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); block_on(async {
let mut cors = Cors::new()
.disable_preflight()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
@ -1000,7 +1016,7 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://www.example.com"[..], &b"https://www.example.com"[..],
resp.headers() resp.headers()
@ -1008,10 +1024,12 @@ mod tests {
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
#[test] #[test]
fn test_response() { fn test_response() {
block_on(async {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new() let mut cors = Cors::new()
.send_wildcard() .send_wildcard()
@ -1021,13 +1039,16 @@ mod tests {
.allowed_headers(exposed_headers.clone()) .allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@ -1065,15 +1086,18 @@ mod tests {
.allowed_headers(exposed_headers.clone()) .allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish(|req: ServiceRequest| { .finish()
req.into_response( .new_transform(service_fn2(|req: ServiceRequest| {
ok(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(), HttpResponse::Ok().header(header::VARY, "Accept").finish(),
) ))
}); }))
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"Accept, Origin"[..], &b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes() resp.headers().get(header::VARY).unwrap().as_bytes()
@ -1083,13 +1107,16 @@ mod tests {
.disable_vary_header() .disable_vary_header()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com") .allowed_origin("https://www.google.com")
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
let origins_str = resp let origins_str = resp
.headers() .headers()
@ -1099,21 +1126,26 @@ mod tests {
.unwrap(); .unwrap();
assert_eq!("https://www.example.com", origins_str); assert_eq!("https://www.example.com", origins_str);
})
} }
#[test] #[test]
fn test_multiple_origins() { fn test_multiple_origins() {
block_on(async {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@ -1126,7 +1158,7 @@ mod tests {
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
@ -1134,22 +1166,27 @@ mod tests {
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
#[test] #[test]
fn test_multiple_origins_preflight() { fn test_multiple_origins_preflight() {
block_on(async {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
.finish(test::ok_service()); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@ -1163,7 +1200,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
@ -1171,5 +1208,6 @@ mod tests {
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
} }