1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

CorsBuilder::finish() panics on any configuration error

This commit is contained in:
Nikolay Kim 2018-04-09 14:20:12 -07:00
parent 7df2d6b12a
commit be358db422

View File

@ -34,7 +34,7 @@
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600)
//! .finish().expect("Can not create CORS middleware")
//! .finish()
//! .register(r); // <- Register CORS middleware
//! r.method(http::Method::GET).f(|_| HttpResponse::Ok());
//! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed());
@ -47,6 +47,7 @@
//! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet;
use std::iter::FromIterator;
use std::rc::Rc;
use http::{self, Method, HttpTryFrom, Uri, StatusCode};
use http::header::{self, HeaderName, HeaderValue};
@ -91,19 +92,6 @@ pub enum CorsError {
HeadersNotAllowed,
}
/// A set of errors that can occur during building CORS middleware
#[derive(Debug, Fail)]
pub enum CorsBuilderError {
#[fail(display="Parse error: {}", _0)]
ParseError(http::Error),
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
///
/// This is a misconfiguration. Check the documentation for `Cors`.
#[fail(display="Credentials are allowed, but the Origin is set to \"*\"")]
CredentialsWithWildcardOrigin,
}
impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse {
@ -155,7 +143,12 @@ impl<T> AllOrSome<T> {
///
/// The Cors struct contains the settings for CORS requests to be validated and
/// for responses to be generated.
#[derive(Clone)]
pub struct Cors {
inner: Rc<Inner>,
}
struct Inner {
methods: HashSet<Method>,
origins: AllOrSome<HashSet<String>>,
origins_str: Option<HeaderValue>,
@ -170,7 +163,7 @@ pub struct Cors {
impl Default for Cors {
fn default() -> Cors {
Cors {
let inner = Inner {
origins: AllOrSome::default(),
origins_str: None,
methods: HashSet::from_iter(
@ -184,14 +177,15 @@ impl Default for Cors {
send_wildcard: false,
supports_credentials: false,
vary_header: true,
}
};
Cors{inner: Rc::new(inner)}
}
}
impl Cors {
pub fn build() -> CorsBuilder {
CorsBuilder {
cors: Some(Cors {
cors: Some(Inner {
origins: AllOrSome::All,
origins_str: None,
methods: HashSet::new(),
@ -223,7 +217,7 @@ impl Cors {
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ORIGIN) {
if let Ok(origin) = hdr.to_str() {
return match self.origins {
return match self.inner.origins {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => {
allowed_origins
@ -235,7 +229,7 @@ impl Cors {
}
Err(CorsError::BadOrigin)
} else {
return match self.origins {
return match self.inner.origins {
AllOrSome::All => Ok(()),
_ => Err(CorsError::MissingOrigin)
}
@ -246,7 +240,7 @@ impl Cors {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) {
return self.methods.get(&method)
return self.inner.methods.get(&method)
.and_then(|_| Some(()))
.ok_or_else(|| CorsError::MethodNotAllowed);
}
@ -258,7 +252,7 @@ impl Cors {
}
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
match self.headers {
match self.inner.headers {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
@ -288,13 +282,13 @@ impl Cors {
impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
if self.preflight && Method::OPTIONS == *req.method() {
if self.inner.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?;
self.validate_allowed_method(req)?;
self.validate_allowed_headers(req)?;
// allowed headers
let headers = if let Some(headers) = self.headers.as_ref() {
let headers = if let Some(headers) = self.inner.headers.as_ref() {
Some(HeaderValue::try_from(&headers.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap())
} else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
@ -305,13 +299,13 @@ impl<S> Middleware<S> for Cors {
Ok(Started::Response(
HttpResponse::Ok()
.if_some(self.max_age.as_ref(), |max_age, resp| {
.if_some(self.inner.max_age.as_ref(), |max_age, resp| {
let _ = resp.header(
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
.if_some(headers, |headers, resp| {
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); })
.if_true(self.origins.is_all(), |resp| {
if self.send_wildcard {
.if_true(self.inner.origins.is_all(), |resp| {
if self.inner.send_wildcard {
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
} else {
let origin = req.headers().get(header::ORIGIN).unwrap();
@ -319,17 +313,17 @@ impl<S> Middleware<S> for Cors {
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
}
})
.if_true(self.origins.is_some(), |resp| {
.if_true(self.inner.origins.is_some(), |resp| {
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone());
self.inner.origins_str.as_ref().unwrap().clone());
})
.if_true(self.supports_credentials, |resp| {
.if_true(self.inner.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
})
.header(
header::ACCESS_CONTROL_ALLOW_METHODS,
&self.methods.iter().fold(
&self.inner.methods.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
.finish()))
} else {
@ -340,9 +334,9 @@ impl<S> Middleware<S> for Cors {
}
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
match self.origins {
match self.inner.origins {
AllOrSome::All => {
if self.send_wildcard {
if self.inner.send_wildcard {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
@ -353,20 +347,20 @@ impl<S> Middleware<S> for Cors {
AllOrSome::Some(_) => {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.origins_str.as_ref().unwrap().clone());
self.inner.origins_str.as_ref().unwrap().clone());
}
}
if let Some(ref expose) = self.expose_hdrs {
if let Some(ref expose) = self.inner.expose_hdrs {
resp.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap());
}
if self.supports_credentials {
if self.inner.supports_credentials {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
}
if self.vary_header {
if self.inner.vary_header {
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes());
@ -404,17 +398,19 @@ impl<S> Middleware<S> for Cors {
/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
/// .allowed_header(header::CONTENT_TYPE)
/// .max_age(3600)
/// .finish().unwrap();
/// .finish();
/// # }
/// ```
pub struct CorsBuilder {
cors: Option<Cors>,
cors: Option<Inner>,
methods: bool,
error: Option<http::Error>,
expose_hdrs: HashSet<HeaderName>,
}
fn cors<'a>(parts: &'a mut Option<Cors>, err: &Option<http::Error>) -> Option<&'a mut Cors> {
fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<http::Error>)
-> Option<&'a mut Inner>
{
if err.is_some() {
return None
}
@ -437,6 +433,8 @@ impl CorsBuilder {
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// Defaults to `All`.
///
/// Builder panics if supplied origin is not valid uri.
pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) {
match Uri::try_from(origin) {
@ -602,6 +600,9 @@ impl CorsBuilder {
/// and `send_wildcards` set to `true`.
///
/// Defaults to `false`.
///
/// Builder panics if credentials are allowed, but the Origin is set to "*".
/// This is not allowed by W3C
pub fn supports_credentials(&mut self) -> &mut CorsBuilder {
if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.supports_credentials = true
@ -641,7 +642,9 @@ impl CorsBuilder {
}
/// Finishes building and returns the built `Cors` instance.
pub fn finish(&mut self) -> Result<Cors, CorsBuilderError> {
///
/// This method panics in case of any configuration error.
pub fn finish(&mut self) -> Cors {
if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT,
@ -649,13 +652,13 @@ impl CorsBuilder {
}
if let Some(e) = self.error.take() {
return Err(CorsBuilderError::ParseError(e))
panic!("{}", e);
}
let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
return Err(CorsBuilderError::CredentialsWithWildcardOrigin)
panic!("Credentials are allowed, but the Origin is set to \"*\"");
}
if let AllOrSome::Some(ref origins) = cors.origins {
@ -668,7 +671,7 @@ impl CorsBuilder {
self.expose_hdrs.iter().fold(
String::new(), |s, v| s + v.as_str())[1..].to_owned());
}
Ok(cors)
Cors{inner: Rc::new(cors)}
}
}
@ -702,13 +705,12 @@ mod tests {
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() {
Cors::build()
.supports_credentials()
.send_wildcard()
.finish()
.unwrap();
.finish();
}
#[test]
@ -728,7 +730,7 @@ mod tests {
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE)
.finish().unwrap();
.finish();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
@ -764,7 +766,7 @@ mod tests {
// &b"POST,GET,OPTIONS"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
cors.preflight = false;
Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
assert!(cors.start(&mut req).unwrap().is_done());
}
@ -772,7 +774,7 @@ mod tests {
#[should_panic(expected = "MissingOrigin")]
fn test_validate_missing_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap();
.allowed_origin("https://www.example.com").finish();
let mut req = HttpRequest::default();
cors.start(&mut req).unwrap();
@ -782,7 +784,7 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap();
.allowed_origin("https://www.example.com").finish();
let mut req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET)
@ -793,7 +795,7 @@ mod tests {
#[test]
fn test_validate_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish().unwrap();
.allowed_origin("https://www.example.com").finish();
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET)
@ -804,7 +806,7 @@ mod tests {
#[test]
fn test_no_origin_response() {
let cors = Cors::build().finish().unwrap();
let cors = Cors::build().finish();
let mut req = TestRequest::default().method(Method::GET).finish();
let resp: HttpResponse = HttpResponse::Ok().into();
@ -830,7 +832,7 @@ mod tests {
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE)
.finish().unwrap();
.finish();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
@ -857,7 +859,7 @@ mod tests {
let cors = Cors::build()
.disable_vary_header()
.allowed_origin("https://www.example.com")
.finish().unwrap();
.finish();
let resp: HttpResponse = HttpResponse::Ok().into();
let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!(