1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-17 21:34:32 +01:00

better names for cors errors

This commit is contained in:
Nikolay Kim 2018-01-10 14:20:00 -08:00
parent 615db0d9d8
commit 16e9512457

View File

@ -58,7 +58,7 @@ use middleware::{Middleware, Response, Started};
/// A set of errors that can occur during processing CORS
#[derive(Debug, Fail)]
pub enum Error {
pub enum CorsError {
/// The HTTP request header `Origin` is required but was not provided
#[fail(display="The HTTP request header `Origin` is required but was not provided")]
MissingOrigin,
@ -91,7 +91,7 @@ pub enum Error {
/// A set of errors that can occur during building CORS middleware
#[derive(Debug, Fail)]
pub enum BuilderError {
pub enum CorsBuilderError {
#[fail(display="Parse error: {}", _0)]
ParseError(http::Error),
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
@ -102,7 +102,7 @@ pub enum BuilderError {
}
impl ResponseError for Error {
impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse {
HTTPBadRequest.into()
@ -169,7 +169,7 @@ pub struct Cors {
impl Default for Cors {
fn default() -> Cors {
Cors {
origins: AllOrSome::All,
origins: AllOrSome::default(),
origins_str: None,
methods: HashSet::from_iter(
vec![Method::GET, Method::HEAD,
@ -207,7 +207,7 @@ impl Cors {
}
}
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> {
fn validate_origin<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ORIGIN) {
if let Ok(origin) = hdr.to_str() {
if let Ok(uri) = Uri::try_from(origin) {
@ -217,33 +217,33 @@ impl Cors {
allowed_origins
.get(&uri)
.and_then(|_| Some(()))
.ok_or_else(|| Error::OriginNotAllowed)
.ok_or_else(|| CorsError::OriginNotAllowed)
}
};
}
}
Err(Error::BadOrigin)
Err(CorsError::BadOrigin)
} else {
Ok(())
}
}
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> {
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) {
return self.methods.get(&method)
.and_then(|_| Some(()))
.ok_or_else(|| Error::MethodNotAllowed);
.ok_or_else(|| CorsError::MethodNotAllowed);
}
}
Err(Error::BadRequestMethod)
Err(CorsError::BadRequestMethod)
} else {
Err(Error::MissingRequestMethod)
Err(CorsError::MissingRequestMethod)
}
}
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), Error> {
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
if let Ok(headers) = hdr.to_str() {
match self.headers {
@ -253,20 +253,20 @@ impl Cors {
for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) {
Ok(hdr) => hdrs.insert(hdr),
Err(_) => return Err(Error::BadRequestHeaders)
Err(_) => return Err(CorsError::BadRequestHeaders)
};
}
if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) {
return Err(Error::HeadersNotAllowed)
return Err(CorsError::HeadersNotAllowed)
}
return Ok(())
}
}
}
Err(Error::BadRequestHeaders)
Err(CorsError::BadRequestHeaders)
} else {
Err(Error::MissingRequestHeaders)
Err(CorsError::MissingRequestHeaders)
}
}
}
@ -626,7 +626,7 @@ impl CorsBuilder {
}
/// Finishes building and returns the built `Cors` instance.
pub fn finish(&mut self) -> Result<Cors, BuilderError> {
pub fn finish(&mut self) -> Result<Cors, CorsBuilderError> {
if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT,
@ -634,13 +634,13 @@ impl CorsBuilder {
}
if let Some(e) = self.error.take() {
return Err(BuilderError::ParseError(e))
return Err(CorsBuilderError::ParseError(e))
}
let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
return Err(BuilderError::CredentialsWithWildcardOrigin)
return Err(CorsBuilderError::CredentialsWithWildcardOrigin)
}
if let AllOrSome::Some(ref origins) = cors.origins {
@ -744,4 +744,17 @@ mod tests {
cors.preflight = false;
assert!(cors.start(&mut req).unwrap().is_done());
}
#[test]
fn test_validate_origin() {
let cors = Cors::build()
.allowed_origin("http://www.example.com").finish().unwrap();
let mut req = TestRequest::with_header(
"Origin", "https://www.unknown.com")
.method(Method::GET)
.finish();
assert!(cors.start(&mut req).is_err());
}
}