diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 94ed878e..63582aeb 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -396,7 +396,7 @@ impl HttpResponseBuilder { /// This method calls provided closure with builder reference if value is true. pub fn if_true(&mut self, value: bool, f: F) -> &mut Self - where F: Fn(&mut HttpResponseBuilder) + 'static + where F: FnOnce(&mut HttpResponseBuilder) { if value { f(self); @@ -406,7 +406,7 @@ impl HttpResponseBuilder { /// This method calls provided closure with builder reference if value is Some. pub fn if_some(&mut self, value: Option<&T>, f: F) -> &mut Self - where F: Fn(&T, &mut HttpResponseBuilder) + 'static + where F: FnOnce(&T, &mut HttpResponseBuilder) { if let Some(val) = value { f(val, self); diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 630a6eb7..a4a011d9 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -45,15 +45,16 @@ //! //! Cors middleware automatically handle *OPTIONS* preflight request. use std::collections::HashSet; +use std::iter::FromIterator; use http::{self, Method, HttpTryFrom, Uri}; -use http::header::{self, HeaderName}; +use http::header::{self, HeaderName, HeaderValue}; use error::{Result, ResponseError}; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use middleware::{Middleware, Response, Started}; use httpcodes::{HTTPOk, HTTPBadRequest}; +use middleware::{Middleware, Response, Started}; /// A set of errors that can occur during processing CORS #[derive(Debug, Fail)] @@ -86,6 +87,13 @@ pub enum Error { /// One or more headers requested are not allowed #[fail(display="One or more headers requested are not allowed")] HeadersNotAllowed, +} + +/// A set of errors that can occur during building CORS middleware +#[derive(Debug, Fail)] +pub enum BuilderError { + #[fail(display="Parse error: {}", _0)] + ParseError(http::Error), /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C /// /// This is a misconfiguration. Check the docuemntation for `Cors`. @@ -93,6 +101,7 @@ pub enum Error { CredentialsWithWildcardOrigin, } + impl ResponseError for Error { fn error_response(&self) -> HttpResponse { @@ -147,9 +156,34 @@ impl AllOrSome { pub struct Cors { methods: HashSet, origins: AllOrSome>, + origins_str: Option, headers: AllOrSome>, + expose_hdrs: Option, max_age: Option, - send_wildcards: bool, + preflight: bool, + send_wildcard: bool, + supports_credentials: bool, + vary_header: bool, +} + +impl Default for Cors { + fn default() -> Cors { + Cors { + origins: AllOrSome::All, + origins_str: None, + methods: HashSet::from_iter( + vec![Method::GET, Method::HEAD, + Method::POST, Method::OPTIONS, Method::PUT, + Method::PATCH, Method::DELETE].into_iter()), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + } + } } impl Cors { @@ -157,13 +191,19 @@ impl Cors { CorsBuilder { cors: Some(Cors { origins: AllOrSome::All, + origins_str: None, methods: HashSet::new(), headers: AllOrSome::All, + expose_hdrs: None, max_age: None, - send_wildcards: false, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, }), methods: false, error: None, + expose_hdrs: HashSet::new(), } } @@ -234,31 +274,90 @@ impl Cors { impl Middleware for Cors { fn start(&self, req: &mut HttpRequest) -> Result { - if Method::OPTIONS == *req.method() { + if self.preflight && Method::OPTIONS == *req.method() { self.validate_origin(req)?; self.validate_allowed_method(req)?; self.validate_allowed_headers(req)?; Ok(Started::Response( HTTPOk.build() - .if_some(self.max_age.as_ref(), |max_age, res| { - let _ = res.header( + .if_some(self.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) - .if_some(self.headers.as_ref(), |headers, res| { - let _ = res.header( + .if_some(self.headers.as_ref(), |headers, resp| { + let _ = resp.header( header::ACCESS_CONTROL_ALLOW_HEADERS, - headers.iter().fold(String::new(), |s, v| s + v.as_str()).as_str());}) + &headers.iter().fold( + String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]);}) + .if_true(self.origins.is_all(), |resp| { + if self.send_wildcard { + resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); + } else { + let origin = req.headers().get(header::ORIGIN).unwrap(); + resp.header( + header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); + } + }) + .if_true(self.origins.is_some(), |resp| { + resp.header( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + self.origins_str.as_ref().unwrap().clone()); + }) + .if_true(self.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) .header( header::ACCESS_CONTROL_ALLOW_METHODS, - self.methods.iter().fold(String::new(), |s, v| s + v.as_str()).as_str()) + &self.methods.iter().fold( + String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]) .finish() .unwrap())) } else { + self.validate_origin(req)?; + Ok(Started::Done) } } - fn response(&self, _: &mut HttpRequest, resp: HttpResponse) -> Result { + fn response(&self, req: &mut HttpRequest, mut resp: HttpResponse) -> Result { + match self.origins { + AllOrSome::All => { + if self.send_wildcard { + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + } else { + let origin = req.headers().get(header::ORIGIN).unwrap(); + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); + } + } + AllOrSome::Some(_) => { + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + self.origins_str.as_ref().unwrap().clone()); + } + } + + if let Some(ref expose) = self.expose_hdrs { + resp.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap()); + } + if self.supports_credentials { + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); + } + if self.vary_header { + let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { + let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + HeaderValue::try_from(&val[..]).unwrap() + } else { + HeaderValue::from_static("Origin") + }; + resp.headers_mut().insert(header::VARY, value); + } Ok(Response::Done(resp)) } } @@ -293,6 +392,7 @@ pub struct CorsBuilder { cors: Option, methods: bool, error: Option, + expose_hdrs: HashSet, } fn cors<'a>(parts: &'a mut Option, err: &Option) -> Option<&'a mut Cors> { @@ -421,6 +521,30 @@ impl CorsBuilder { self } + /// Set a list of headers which are safe to expose to the API of a CORS API specification. + /// This corresponds to the `Access-Control-Expose-Headers` responde header. + /// + /// This is the `list of exposed headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This defaults to an empty set. + pub fn expose_headers(&mut self, headers: U) -> &mut CorsBuilder + where U: IntoIterator, HeaderName: HttpTryFrom + { + for h in headers { + match HeaderName::try_from(h) { + Ok(method) => { + self.expose_hdrs.insert(method); + }, + Err(e) => { + self.error = Some(e.into()); + break + } + } + } + self + } + /// Set a maximum time for which this CORS request maybe cached. /// This value is set as the `Access-Control-Max-Age` header. /// @@ -449,13 +573,60 @@ impl CorsBuilder { #[cfg_attr(feature = "serialization", serde(default))] pub fn send_wildcard(&mut self) -> &mut CorsBuilder { if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.send_wildcards = true + cors.send_wildcard = true } self } - + + /// Allows users to make authenticated requests + /// + /// If true, injects the `Access-Control-Allow-Credentials` header in responses. + /// This allows cookies and credentials to be submitted across domains. + /// + /// This option cannot be used in conjuction with an `allowed_origin` set to `All` + /// and `send_wildcards` set to `true`. + /// + /// Defaults to `false`. + pub fn supports_credentials(&mut self) -> &mut CorsBuilder { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.supports_credentials = true + } + self + } + + /// Disable `Vary` header support. + /// + /// When enabled the header `Vary: Origin` will be returned as per the W3 + /// implementation guidelines. + /// + /// Setting this header when the `Access-Control-Allow-Origin` is + /// dynamically generated (e.g. when there is more than one allowed + /// origin, and an Origin than '*' is returned) informs CDNs and other + /// caches that the CORS headers are dynamic, and cannot be cached. + /// + /// By default `vary` header support is enabled. + pub fn disable_vary_header(&mut self) -> &mut CorsBuilder { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.vary_header = false + } + self + } + + /// Disable *preflight* request support. + /// + /// When enabled cors middleware automatically handles *OPTIONS* request. + /// This is useful application level middleware. + /// + /// By default *preflight* support is enabled. + pub fn disable_preflight(&mut self) -> &mut CorsBuilder { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.preflight = false + } + self + } + /// Finishes building and returns the built `Cors` instance. - pub fn finish(&mut self) -> Result { + pub fn finish(&mut self) -> Result { if !self.methods { self.allowed_methods(vec![Method::GET, Method::HEAD, Method::POST, Method::OPTIONS, Method::PUT, @@ -463,9 +634,114 @@ impl CorsBuilder { } if let Some(e) = self.error.take() { - return Err(e) + return Err(BuilderError::ParseError(e)) } - Ok(self.cors.take().expect("cannot reuse CorsBuilder")) + let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); + + if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { + return Err(BuilderError::CredentialsWithWildcardOrigin) + } + + if let AllOrSome::Some(ref origins) = cors.origins { + let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v)); + cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap()); + } + + if !self.expose_hdrs.is_empty() { + cors.expose_hdrs = Some( + self.expose_hdrs.iter().fold( + String::new(), |s, v| s + v.as_str())[1..].to_owned()); + } + Ok(cors) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use test::TestRequest; + + impl Started { + fn is_done(&self) -> bool { + match *self { + Started::Done => true, + _ => false, + } + } + fn response(self) -> HttpResponse { + match self { + Started::Response(resp) => resp, + _ => panic!(), + } + } + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn cors_validates_illegal_allow_credentials() { + Cors::build() + .supports_credentials() + .send_wildcard() + .finish() + .unwrap(); + } + + #[test] + fn validate_origin_allows_all_origins() { + let cors = Cors::default(); + let mut req = TestRequest::with_header( + "Origin", "https://www.example.com").finish(); + + assert!(cors.start(&mut req).ok().unwrap().is_done()) + } + + #[test] + fn test_preflight() { + let mut cors = Cors::build() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish().unwrap(); + + let mut req = TestRequest::with_header( + "Origin", "https://www.example.com") + .method(Method::OPTIONS) + .finish(); + + assert!(cors.start(&mut req).is_err()); + + let mut req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .finish(); + + assert!(cors.start(&mut req).is_err()); + + let mut req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT") + .method(Method::OPTIONS) + .finish(); + + let resp = cors.start(&mut req).unwrap().response(); + assert_eq!( + &b"*"[..], + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); + assert_eq!( + &b"3600"[..], + resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()); + //assert_eq!( + // &b"authorization,accept,content-type"[..], + // resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes()); + //assert_eq!( + // &b"POST,GET,OPTIONS"[..], + // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes()); + + cors.preflight = false; + assert!(cors.start(&mut req).unwrap().is_done()); } }