1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-02-02 18:59:04 +01:00

better names for cors errors

This commit is contained in:
Nikolay Kim 2018-01-10 14:20:00 -08:00
parent 615db0d9d8
commit 16e9512457

View File

@ -58,7 +58,7 @@ use middleware::{Middleware, Response, Started};
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
pub enum Error { pub enum CorsError {
/// The HTTP request header `Origin` is required but was not provided /// The HTTP request header `Origin` is required but was not provided
#[fail(display="The HTTP request header `Origin` is required but was not provided")] #[fail(display="The HTTP request header `Origin` is required but was not provided")]
MissingOrigin, MissingOrigin,
@ -91,7 +91,7 @@ pub enum Error {
/// A set of errors that can occur during building CORS middleware /// A set of errors that can occur during building CORS middleware
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
pub enum BuilderError { pub enum CorsBuilderError {
#[fail(display="Parse error: {}", _0)] #[fail(display="Parse error: {}", _0)]
ParseError(http::Error), ParseError(http::Error),
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
@ -102,7 +102,7 @@ pub enum BuilderError {
} }
impl ResponseError for Error { impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HTTPBadRequest.into() HTTPBadRequest.into()
@ -169,7 +169,7 @@ pub struct Cors {
impl Default for Cors { impl Default for Cors {
fn default() -> Cors { fn default() -> Cors {
Cors { Cors {
origins: AllOrSome::All, origins: AllOrSome::default(),
origins_str: None, origins_str: None,
methods: HashSet::from_iter( methods: HashSet::from_iter(
vec![Method::GET, Method::HEAD, vec![Method::GET, Method::HEAD,
@ -207,7 +207,7 @@ impl Cors {
} }
} }
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> { fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Some(hdr) = req.headers().get(header::ORIGIN) {
if let Ok(origin) = hdr.to_str() { if let Ok(origin) = hdr.to_str() {
if let Ok(uri) = Uri::try_from(origin) { if let Ok(uri) = Uri::try_from(origin) {
@ -217,33 +217,33 @@ impl Cors {
allowed_origins allowed_origins
.get(&uri) .get(&uri)
.and_then(|_| Some(())) .and_then(|_| Some(()))
.ok_or_else(|| Error::OriginNotAllowed) .ok_or_else(|| CorsError::OriginNotAllowed)
} }
}; };
} }
} }
Err(Error::BadOrigin) Err(CorsError::BadOrigin)
} else { } else {
Ok(()) Ok(())
} }
} }
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> { fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() { if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) { if let Ok(method) = Method::try_from(meth) {
return self.methods.get(&method) return self.methods.get(&method)
.and_then(|_| Some(())) .and_then(|_| Some(()))
.ok_or_else(|| Error::MethodNotAllowed); .ok_or_else(|| CorsError::MethodNotAllowed);
} }
} }
Err(Error::BadRequestMethod) Err(CorsError::BadRequestMethod)
} else { } else {
Err(Error::MissingRequestMethod) Err(CorsError::MissingRequestMethod)
} }
} }
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> { fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
if let Ok(headers) = hdr.to_str() { if let Ok(headers) = hdr.to_str() {
match self.headers { match self.headers {
@ -253,20 +253,20 @@ impl Cors {
for hdr in headers.split(',') { for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) { match HeaderName::try_from(hdr.trim()) {
Ok(hdr) => hdrs.insert(hdr), Ok(hdr) => hdrs.insert(hdr),
Err(_) => return Err(Error::BadRequestHeaders) Err(_) => return Err(CorsError::BadRequestHeaders)
}; };
} }
if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) {
return Err(Error::HeadersNotAllowed) return Err(CorsError::HeadersNotAllowed)
} }
return Ok(()) return Ok(())
} }
} }
} }
Err(Error::BadRequestHeaders) Err(CorsError::BadRequestHeaders)
} else { } else {
Err(Error::MissingRequestHeaders) Err(CorsError::MissingRequestHeaders)
} }
} }
} }
@ -626,7 +626,7 @@ impl CorsBuilder {
} }
/// Finishes building and returns the built `Cors` instance. /// Finishes building and returns the built `Cors` instance.
pub fn finish(&mut self) -> Result<Cors, BuilderError> { pub fn finish(&mut self) -> Result<Cors, CorsBuilderError> {
if !self.methods { if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD, self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT, Method::POST, Method::OPTIONS, Method::PUT,
@ -634,13 +634,13 @@ impl CorsBuilder {
} }
if let Some(e) = self.error.take() { if let Some(e) = self.error.take() {
return Err(BuilderError::ParseError(e)) return Err(CorsBuilderError::ParseError(e))
} }
let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
return Err(BuilderError::CredentialsWithWildcardOrigin) return Err(CorsBuilderError::CredentialsWithWildcardOrigin)
} }
if let AllOrSome::Some(ref origins) = cors.origins { if let AllOrSome::Some(ref origins) = cors.origins {
@ -744,4 +744,17 @@ mod tests {
cors.preflight = false; cors.preflight = false;
assert!(cors.start(&mut req).unwrap().is_done()); assert!(cors.start(&mut req).unwrap().is_done());
} }
#[test]
fn test_validate_origin() {
let cors = Cors::build()
.allowed_origin("http://www.example.com").finish().unwrap();
let mut req = TestRequest::with_header(
"Origin", "https://www.unknown.com")
.method(Method::GET)
.finish();
assert!(cors.start(&mut req).is_err());
}
} }