1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 00:21:08 +01:00

complete cors implementation

This commit is contained in:
Nikolay Kim 2018-01-10 13:41:33 -08:00
parent 3f3dcf413b
commit 615db0d9d8
2 changed files with 295 additions and 19 deletions

View File

@ -396,7 +396,7 @@ impl HttpResponseBuilder {
/// This method calls provided closure with builder reference if value is true. /// This method calls provided closure with builder reference if value is true.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: Fn(&mut HttpResponseBuilder) + 'static where F: FnOnce(&mut HttpResponseBuilder)
{ {
if value { if value {
f(self); f(self);
@ -406,7 +406,7 @@ impl HttpResponseBuilder {
/// This method calls provided closure with builder reference if value is Some. /// This method calls provided closure with builder reference if value is Some.
pub fn if_some<T, F>(&mut self, value: Option<&T>, f: F) -> &mut Self pub fn if_some<T, F>(&mut self, value: Option<&T>, f: F) -> &mut Self
where F: Fn(&T, &mut HttpResponseBuilder) + 'static where F: FnOnce(&T, &mut HttpResponseBuilder)
{ {
if let Some(val) = value { if let Some(val) = value {
f(val, self); f(val, self);

View File

@ -45,15 +45,16 @@
//! //!
//! Cors middleware automatically handle *OPTIONS* preflight request. //! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator;
use http::{self, Method, HttpTryFrom, Uri}; use http::{self, Method, HttpTryFrom, Uri};
use http::header::{self, HeaderName}; use http::header::{self, HeaderName, HeaderValue};
use error::{Result, ResponseError}; use error::{Result, ResponseError};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Response, Started};
use httpcodes::{HTTPOk, HTTPBadRequest}; use httpcodes::{HTTPOk, HTTPBadRequest};
use middleware::{Middleware, Response, Started};
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
@ -86,6 +87,13 @@ pub enum Error {
/// One or more headers requested are not allowed /// One or more headers requested are not allowed
#[fail(display="One or more headers requested are not allowed")] #[fail(display="One or more headers requested are not allowed")]
HeadersNotAllowed, HeadersNotAllowed,
}
/// A set of errors that can occur during building CORS middleware
#[derive(Debug, Fail)]
pub enum BuilderError {
#[fail(display="Parse error: {}", _0)]
ParseError(http::Error),
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
/// ///
/// This is a misconfiguration. Check the docuemntation for `Cors`. /// This is a misconfiguration. Check the docuemntation for `Cors`.
@ -93,6 +101,7 @@ pub enum Error {
CredentialsWithWildcardOrigin, CredentialsWithWildcardOrigin,
} }
impl ResponseError for Error { impl ResponseError for Error {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
@ -147,9 +156,34 @@ impl<T> AllOrSome<T> {
pub struct Cors { pub struct Cors {
methods: HashSet<Method>, methods: HashSet<Method>,
origins: AllOrSome<HashSet<Uri>>, origins: AllOrSome<HashSet<Uri>>,
origins_str: Option<HeaderValue>,
headers: AllOrSome<HashSet<HeaderName>>, headers: AllOrSome<HashSet<HeaderName>>,
expose_hdrs: Option<String>,
max_age: Option<usize>, max_age: Option<usize>,
send_wildcards: bool, preflight: bool,
send_wildcard: bool,
supports_credentials: bool,
vary_header: bool,
}
impl Default for Cors {
fn default() -> Cors {
Cors {
origins: AllOrSome::All,
origins_str: None,
methods: HashSet::from_iter(
vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT,
Method::PATCH, Method::DELETE].into_iter()),
headers: AllOrSome::All,
expose_hdrs: None,
max_age: None,
preflight: true,
send_wildcard: false,
supports_credentials: false,
vary_header: true,
}
}
} }
impl Cors { impl Cors {
@ -157,13 +191,19 @@ impl Cors {
CorsBuilder { CorsBuilder {
cors: Some(Cors { cors: Some(Cors {
origins: AllOrSome::All, origins: AllOrSome::All,
origins_str: None,
methods: HashSet::new(), methods: HashSet::new(),
headers: AllOrSome::All, headers: AllOrSome::All,
expose_hdrs: None,
max_age: None, max_age: None,
send_wildcards: false, preflight: true,
send_wildcard: false,
supports_credentials: false,
vary_header: true,
}), }),
methods: false, methods: false,
error: None, error: None,
expose_hdrs: HashSet::new(),
} }
} }
@ -234,31 +274,90 @@ impl Cors {
impl<S> Middleware<S> for Cors { impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
if Method::OPTIONS == *req.method() { if self.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?; self.validate_origin(req)?;
self.validate_allowed_method(req)?; self.validate_allowed_method(req)?;
self.validate_allowed_headers(req)?; self.validate_allowed_headers(req)?;
Ok(Started::Response( Ok(Started::Response(
HTTPOk.build() HTTPOk.build()
.if_some(self.max_age.as_ref(), |max_age, res| { .if_some(self.max_age.as_ref(), |max_age, resp| {
let _ = res.header( let _ = resp.header(
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
.if_some(self.headers.as_ref(), |headers, res| { .if_some(self.headers.as_ref(), |headers, resp| {
let _ = res.header( let _ = resp.header(
header::ACCESS_CONTROL_ALLOW_HEADERS, header::ACCESS_CONTROL_ALLOW_HEADERS,
headers.iter().fold(String::new(), |s, v| s + v.as_str()).as_str());}) &headers.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]);})
.if_true(self.origins.is_all(), |resp| {
if self.send_wildcard {
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
} else {
let origin = req.headers().get(header::ORIGIN).unwrap();
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
}
})
.if_true(self.origins.is_some(), |resp| {
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone());
})
.if_true(self.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
})
.header( .header(
header::ACCESS_CONTROL_ALLOW_METHODS, header::ACCESS_CONTROL_ALLOW_METHODS,
self.methods.iter().fold(String::new(), |s, v| s + v.as_str()).as_str()) &self.methods.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
.finish() .finish()
.unwrap())) .unwrap()))
} else { } else {
self.validate_origin(req)?;
Ok(Started::Done) Ok(Started::Done)
} }
} }
fn response(&self, _: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> { fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
match self.origins {
AllOrSome::All => {
if self.send_wildcard {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
} else {
let origin = req.headers().get(header::ORIGIN).unwrap();
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
}
}
AllOrSome::Some(_) => {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone());
}
}
if let Some(ref expose) = self.expose_hdrs {
resp.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap());
}
if self.supports_credentials {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
}
if self.vary_header {
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes());
val.extend(b", Origin");
HeaderValue::try_from(&val[..]).unwrap()
} else {
HeaderValue::from_static("Origin")
};
resp.headers_mut().insert(header::VARY, value);
}
Ok(Response::Done(resp)) Ok(Response::Done(resp))
} }
} }
@ -293,6 +392,7 @@ pub struct CorsBuilder {
cors: Option<Cors>, cors: Option<Cors>,
methods: bool, methods: bool,
error: Option<http::Error>, error: Option<http::Error>,
expose_hdrs: HashSet<HeaderName>,
} }
fn cors<'a>(parts: &'a mut Option<Cors>, err: &Option<http::Error>) -> Option<&'a mut Cors> { fn cors<'a>(parts: &'a mut Option<Cors>, err: &Option<http::Error>) -> Option<&'a mut Cors> {
@ -421,6 +521,30 @@ impl CorsBuilder {
self self
} }
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
/// This corresponds to the `Access-Control-Expose-Headers` responde header.
///
/// This is the `list of exposed headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// This defaults to an empty set.
pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
{
for h in headers {
match HeaderName::try_from(h) {
Ok(method) => {
self.expose_hdrs.insert(method);
},
Err(e) => {
self.error = Some(e.into());
break
}
}
}
self
}
/// Set a maximum time for which this CORS request maybe cached. /// Set a maximum time for which this CORS request maybe cached.
/// This value is set as the `Access-Control-Max-Age` header. /// This value is set as the `Access-Control-Max-Age` header.
/// ///
@ -449,13 +573,60 @@ impl CorsBuilder {
#[cfg_attr(feature = "serialization", serde(default))] #[cfg_attr(feature = "serialization", serde(default))]
pub fn send_wildcard(&mut self) -> &mut CorsBuilder { pub fn send_wildcard(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.send_wildcards = true cors.send_wildcard = true
} }
self self
} }
/// Allows users to make authenticated requests
///
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
/// This allows cookies and credentials to be submitted across domains.
///
/// This option cannot be used in conjuction with an `allowed_origin` set to `All`
/// and `send_wildcards` set to `true`.
///
/// Defaults to `false`.
pub fn supports_credentials(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.supports_credentials = true
}
self
}
/// Disable `Vary` header support.
///
/// When enabled the header `Vary: Origin` will be returned as per the W3
/// implementation guidelines.
///
/// Setting this header when the `Access-Control-Allow-Origin` is
/// dynamically generated (e.g. when there is more than one allowed
/// origin, and an Origin than '*' is returned) informs CDNs and other
/// caches that the CORS headers are dynamic, and cannot be cached.
///
/// By default `vary` header support is enabled.
pub fn disable_vary_header(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.vary_header = false
}
self
}
/// Disable *preflight* request support.
///
/// When enabled cors middleware automatically handles *OPTIONS* request.
/// This is useful application level middleware.
///
/// By default *preflight* support is enabled.
pub fn disable_preflight(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.preflight = false
}
self
}
/// Finishes building and returns the built `Cors` instance. /// Finishes building and returns the built `Cors` instance.
pub fn finish(&mut self) -> Result<Cors, http::Error> { pub fn finish(&mut self) -> Result<Cors, BuilderError> {
if !self.methods { if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD, self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT, Method::POST, Method::OPTIONS, Method::PUT,
@ -463,9 +634,114 @@ impl CorsBuilder {
} }
if let Some(e) = self.error.take() { if let Some(e) = self.error.take() {
return Err(e) return Err(BuilderError::ParseError(e))
} }
Ok(self.cors.take().expect("cannot reuse CorsBuilder")) let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
return Err(BuilderError::CredentialsWithWildcardOrigin)
}
if let AllOrSome::Some(ref origins) = cors.origins {
let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v));
cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap());
}
if !self.expose_hdrs.is_empty() {
cors.expose_hdrs = Some(
self.expose_hdrs.iter().fold(
String::new(), |s, v| s + v.as_str())[1..].to_owned());
}
Ok(cors)
}
}
#[cfg(test)]
mod tests {
use super::*;
use test::TestRequest;
impl Started {
fn is_done(&self) -> bool {
match *self {
Started::Done => true,
_ => false,
}
}
fn response(self) -> HttpResponse {
match self {
Started::Response(resp) => resp,
_ => panic!(),
}
}
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn cors_validates_illegal_allow_credentials() {
Cors::build()
.supports_credentials()
.send_wildcard()
.finish()
.unwrap();
}
#[test]
fn validate_origin_allows_all_origins() {
let cors = Cors::default();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com").finish();
assert!(cors.start(&mut req).ok().unwrap().is_done())
}
#[test]
fn test_preflight() {
let mut cors = Cors::build()
.send_wildcard()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE)
.finish().unwrap();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
.method(Method::OPTIONS)
.finish();
assert!(cors.start(&mut req).is_err());
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
.method(Method::OPTIONS)
.finish();
assert!(cors.start(&mut req).is_err());
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
.method(Method::OPTIONS)
.finish();
let resp = cors.start(&mut req).unwrap().response();
assert_eq!(
&b"*"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
assert_eq!(
&b"3600"[..],
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes());
//assert_eq!(
// &b"authorization,accept,content-type"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes());
//assert_eq!(
// &b"POST,GET,OPTIONS"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
cors.preflight = false;
assert!(cors.start(&mut req).unwrap().is_done());
} }
} }