1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 09:42:40 +01:00

CorsBuilder::finish() panics on any configuration error

This commit is contained in:
Nikolay Kim 2018-04-09 14:20:12 -07:00
parent 7df2d6b12a
commit be358db422

View File

@ -34,7 +34,7 @@
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE) //! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600) //! .max_age(3600)
//! .finish().expect("Can not create CORS middleware") //! .finish()
//! .register(r); // <- Register CORS middleware //! .register(r); // <- Register CORS middleware
//! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); //! r.method(http::Method::GET).f(|_| HttpResponse::Ok());
//! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); //! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed());
@ -47,6 +47,7 @@
//! Cors middleware automatically handle *OPTIONS* preflight request. //! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc;
use http::{self, Method, HttpTryFrom, Uri, StatusCode}; use http::{self, Method, HttpTryFrom, Uri, StatusCode};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
@ -91,19 +92,6 @@ pub enum CorsError {
HeadersNotAllowed, HeadersNotAllowed,
} }
/// A set of errors that can occur during building CORS middleware
#[derive(Debug, Fail)]
pub enum CorsBuilderError {
#[fail(display="Parse error: {}", _0)]
ParseError(http::Error),
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
///
/// This is a misconfiguration. Check the documentation for `Cors`.
#[fail(display="Credentials are allowed, but the Origin is set to \"*\"")]
CredentialsWithWildcardOrigin,
}
impl ResponseError for CorsError { impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
@ -155,7 +143,12 @@ impl<T> AllOrSome<T> {
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
/// for responses to be generated. /// for responses to be generated.
#[derive(Clone)]
pub struct Cors { pub struct Cors {
inner: Rc<Inner>,
}
struct Inner {
methods: HashSet<Method>, methods: HashSet<Method>,
origins: AllOrSome<HashSet<String>>, origins: AllOrSome<HashSet<String>>,
origins_str: Option<HeaderValue>, origins_str: Option<HeaderValue>,
@ -170,7 +163,7 @@ pub struct Cors {
impl Default for Cors { impl Default for Cors {
fn default() -> Cors { fn default() -> Cors {
Cors { let inner = Inner {
origins: AllOrSome::default(), origins: AllOrSome::default(),
origins_str: None, origins_str: None,
methods: HashSet::from_iter( methods: HashSet::from_iter(
@ -184,14 +177,15 @@ impl Default for Cors {
send_wildcard: false, send_wildcard: false,
supports_credentials: false, supports_credentials: false,
vary_header: true, vary_header: true,
} };
Cors{inner: Rc::new(inner)}
} }
} }
impl Cors { impl Cors {
pub fn build() -> CorsBuilder { pub fn build() -> CorsBuilder {
CorsBuilder { CorsBuilder {
cors: Some(Cors { cors: Some(Inner {
origins: AllOrSome::All, origins: AllOrSome::All,
origins_str: None, origins_str: None,
methods: HashSet::new(), methods: HashSet::new(),
@ -223,7 +217,7 @@ impl Cors {
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Some(hdr) = req.headers().get(header::ORIGIN) {
if let Ok(origin) = hdr.to_str() { if let Ok(origin) = hdr.to_str() {
return match self.origins { return match self.inner.origins {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => { AllOrSome::Some(ref allowed_origins) => {
allowed_origins allowed_origins
@ -235,7 +229,7 @@ impl Cors {
} }
Err(CorsError::BadOrigin) Err(CorsError::BadOrigin)
} else { } else {
return match self.origins { return match self.inner.origins {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
_ => Err(CorsError::MissingOrigin) _ => Err(CorsError::MissingOrigin)
} }
@ -246,7 +240,7 @@ impl Cors {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() { if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) { if let Ok(method) = Method::try_from(meth) {
return self.methods.get(&method) return self.inner.methods.get(&method)
.and_then(|_| Some(())) .and_then(|_| Some(()))
.ok_or_else(|| CorsError::MethodNotAllowed); .ok_or_else(|| CorsError::MethodNotAllowed);
} }
@ -258,7 +252,7 @@ impl Cors {
} }
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
match self.headers { match self.inner.headers {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => { AllOrSome::Some(ref allowed_headers) => {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
@ -288,13 +282,13 @@ impl Cors {
impl<S> Middleware<S> for Cors { impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
if self.preflight && Method::OPTIONS == *req.method() { if self.inner.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?; self.validate_origin(req)?;
self.validate_allowed_method(req)?; self.validate_allowed_method(req)?;
self.validate_allowed_headers(req)?; self.validate_allowed_headers(req)?;
// allowed headers // allowed headers
let headers = if let Some(headers) = self.headers.as_ref() { let headers = if let Some(headers) = self.inner.headers.as_ref() {
Some(HeaderValue::try_from(&headers.iter().fold( Some(HeaderValue::try_from(&headers.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap()) String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap())
} else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { } else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
@ -305,13 +299,13 @@ impl<S> Middleware<S> for Cors {
Ok(Started::Response( Ok(Started::Response(
HttpResponse::Ok() HttpResponse::Ok()
.if_some(self.max_age.as_ref(), |max_age, resp| { .if_some(self.inner.max_age.as_ref(), |max_age, resp| {
let _ = resp.header( let _ = resp.header(
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
.if_some(headers, |headers, resp| { .if_some(headers, |headers, resp| {
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); }) let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); })
.if_true(self.origins.is_all(), |resp| { .if_true(self.inner.origins.is_all(), |resp| {
if self.send_wildcard { if self.inner.send_wildcard {
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
} else { } else {
let origin = req.headers().get(header::ORIGIN).unwrap(); let origin = req.headers().get(header::ORIGIN).unwrap();
@ -319,17 +313,17 @@ impl<S> Middleware<S> for Cors {
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
} }
}) })
.if_true(self.origins.is_some(), |resp| { .if_true(self.inner.origins.is_some(), |resp| {
resp.header( resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN, header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone()); self.inner.origins_str.as_ref().unwrap().clone());
}) })
.if_true(self.supports_credentials, |resp| { .if_true(self.inner.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}) })
.header( .header(
header::ACCESS_CONTROL_ALLOW_METHODS, header::ACCESS_CONTROL_ALLOW_METHODS,
&self.methods.iter().fold( &self.inner.methods.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]) String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
.finish())) .finish()))
} else { } else {
@ -340,9 +334,9 @@ impl<S> Middleware<S> for Cors {
} }
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> { fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
match self.origins { match self.inner.origins {
AllOrSome::All => { AllOrSome::All => {
if self.send_wildcard { if self.inner.send_wildcard {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
} else if let Some(origin) = req.headers().get(header::ORIGIN) { } else if let Some(origin) = req.headers().get(header::ORIGIN) {
@ -353,20 +347,20 @@ impl<S> Middleware<S> for Cors {
AllOrSome::Some(_) => { AllOrSome::Some(_) => {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone()); self.inner.origins_str.as_ref().unwrap().clone());
} }
} }
if let Some(ref expose) = self.expose_hdrs { if let Some(ref expose) = self.inner.expose_hdrs {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS, header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap()); HeaderValue::try_from(expose.as_str()).unwrap());
} }
if self.supports_credentials { if self.inner.supports_credentials {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
} }
if self.vary_header { if self.inner.vary_header {
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8); let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
@ -404,17 +398,19 @@ impl<S> Middleware<S> for Cors {
/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) /// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
/// .allowed_header(header::CONTENT_TYPE) /// .allowed_header(header::CONTENT_TYPE)
/// .max_age(3600) /// .max_age(3600)
/// .finish().unwrap(); /// .finish();
/// # } /// # }
/// ``` /// ```
pub struct CorsBuilder { pub struct CorsBuilder {
cors: Option<Cors>, cors: Option<Inner>,
methods: bool, methods: bool,
error: Option<http::Error>, error: Option<http::Error>,
expose_hdrs: HashSet<HeaderName>, expose_hdrs: HashSet<HeaderName>,
} }
fn cors<'a>(parts: &'a mut Option<Cors>, err: &Option<http::Error>) -> Option<&'a mut Cors> { fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<http::Error>)
-> Option<&'a mut Inner>
{
if err.is_some() { if err.is_some() {
return None return None
} }
@ -437,6 +433,8 @@ impl CorsBuilder {
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// Defaults to `All`. /// Defaults to `All`.
///
/// Builder panics if supplied origin is not valid uri.
pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder { pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match Uri::try_from(origin) { match Uri::try_from(origin) {
@ -602,6 +600,9 @@ impl CorsBuilder {
/// and `send_wildcards` set to `true`. /// and `send_wildcards` set to `true`.
/// ///
/// Defaults to `false`. /// Defaults to `false`.
///
/// Builder panics if credentials are allowed, but the Origin is set to "*".
/// This is not allowed by W3C
pub fn supports_credentials(&mut self) -> &mut CorsBuilder { pub fn supports_credentials(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.supports_credentials = true cors.supports_credentials = true
@ -641,7 +642,9 @@ impl CorsBuilder {
} }
/// Finishes building and returns the built `Cors` instance. /// Finishes building and returns the built `Cors` instance.
pub fn finish(&mut self) -> Result<Cors, CorsBuilderError> { ///
/// This method panics in case of any configuration error.
pub fn finish(&mut self) -> Cors {
if !self.methods { if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD, self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT, Method::POST, Method::OPTIONS, Method::PUT,
@ -649,13 +652,13 @@ impl CorsBuilder {
} }
if let Some(e) = self.error.take() { if let Some(e) = self.error.take() {
return Err(CorsBuilderError::ParseError(e)) panic!("{}", e);
} }
let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
return Err(CorsBuilderError::CredentialsWithWildcardOrigin) panic!("Credentials are allowed, but the Origin is set to \"*\"");
} }
if let AllOrSome::Some(ref origins) = cors.origins { if let AllOrSome::Some(ref origins) = cors.origins {
@ -668,7 +671,7 @@ impl CorsBuilder {
self.expose_hdrs.iter().fold( self.expose_hdrs.iter().fold(
String::new(), |s, v| s + v.as_str())[1..].to_owned()); String::new(), |s, v| s + v.as_str())[1..].to_owned());
} }
Ok(cors) Cors{inner: Rc::new(cors)}
} }
} }
@ -702,13 +705,12 @@ mod tests {
} }
#[test] #[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() { fn cors_validates_illegal_allow_credentials() {
Cors::build() Cors::build()
.supports_credentials() .supports_credentials()
.send_wildcard() .send_wildcard()
.finish() .finish();
.unwrap();
} }
#[test] #[test]
@ -728,7 +730,7 @@ mod tests {
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish().unwrap(); .finish();
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
"Origin", "https://www.example.com") "Origin", "https://www.example.com")
@ -764,7 +766,7 @@ mod tests {
// &b"POST,GET,OPTIONS"[..], // &b"POST,GET,OPTIONS"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes()); // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
cors.preflight = false; Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
assert!(cors.start(&mut req).unwrap().is_done()); assert!(cors.start(&mut req).unwrap().is_done());
} }
@ -772,7 +774,7 @@ mod tests {
#[should_panic(expected = "MissingOrigin")] #[should_panic(expected = "MissingOrigin")]
fn test_validate_missing_origin() { fn test_validate_missing_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap(); .allowed_origin("https://www.example.com").finish();
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
cors.start(&mut req).unwrap(); cors.start(&mut req).unwrap();
@ -782,7 +784,7 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap(); .allowed_origin("https://www.example.com").finish();
let mut req = TestRequest::with_header("Origin", "https://www.unknown.com") let mut req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
@ -793,7 +795,7 @@ mod tests {
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap(); .allowed_origin("https://www.example.com").finish();
let mut req = TestRequest::with_header("Origin", "https://www.example.com") let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
@ -804,7 +806,7 @@ mod tests {
#[test] #[test]
fn test_no_origin_response() { fn test_no_origin_response() {
let cors = Cors::build().finish().unwrap(); let cors = Cors::build().finish();
let mut req = TestRequest::default().method(Method::GET).finish(); let mut req = TestRequest::default().method(Method::GET).finish();
let resp: HttpResponse = HttpResponse::Ok().into(); let resp: HttpResponse = HttpResponse::Ok().into();
@ -830,7 +832,7 @@ mod tests {
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish().unwrap(); .finish();
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
"Origin", "https://www.example.com") "Origin", "https://www.example.com")
@ -857,7 +859,7 @@ mod tests {
let cors = Cors::build() let cors = Cors::build()
.disable_vary_header() .disable_vary_header()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish().unwrap(); .finish();
let resp: HttpResponse = HttpResponse::Ok().into(); let resp: HttpResponse = HttpResponse::Ok().into();
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!( assert_eq!(