diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index a4a011d9..a2667430 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -58,7 +58,7 @@ use middleware::{Middleware, Response, Started}; /// A set of errors that can occur during processing CORS #[derive(Debug, Fail)] -pub enum Error { +pub enum CorsError { /// The HTTP request header `Origin` is required but was not provided #[fail(display="The HTTP request header `Origin` is required but was not provided")] MissingOrigin, @@ -91,7 +91,7 @@ pub enum Error { /// A set of errors that can occur during building CORS middleware #[derive(Debug, Fail)] -pub enum BuilderError { +pub enum CorsBuilderError { #[fail(display="Parse error: {}", _0)] ParseError(http::Error), /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C @@ -102,7 +102,7 @@ pub enum BuilderError { } -impl ResponseError for Error { +impl ResponseError for CorsError { fn error_response(&self) -> HttpResponse { HTTPBadRequest.into() @@ -169,7 +169,7 @@ pub struct Cors { impl Default for Cors { fn default() -> Cors { Cors { - origins: AllOrSome::All, + origins: AllOrSome::default(), origins_str: None, methods: HashSet::from_iter( vec![Method::GET, Method::HEAD, @@ -207,7 +207,7 @@ impl Cors { } } - fn validate_origin(&self, req: &mut HttpRequest) -> Result<(), Error> { + fn validate_origin(&self, req: &mut HttpRequest) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Ok(origin) = hdr.to_str() { if let Ok(uri) = Uri::try_from(origin) { @@ -217,33 +217,33 @@ impl Cors { allowed_origins .get(&uri) .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed) + .ok_or_else(|| CorsError::OriginNotAllowed) } }; } } - Err(Error::BadOrigin) + Err(CorsError::BadOrigin) } else { Ok(()) } } - fn validate_allowed_method(&self, req: &mut HttpRequest) -> Result<(), Error> { + fn validate_allowed_method(&self, req: &mut HttpRequest) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Ok(meth) = hdr.to_str() { if let Ok(method) = Method::try_from(meth) { return self.methods.get(&method) .and_then(|_| Some(())) - .ok_or_else(|| Error::MethodNotAllowed); + .ok_or_else(|| CorsError::MethodNotAllowed); } } - Err(Error::BadRequestMethod) + Err(CorsError::BadRequestMethod) } else { - Err(Error::MissingRequestMethod) + Err(CorsError::MissingRequestMethod) } } - fn validate_allowed_headers(&self, req: &mut HttpRequest) -> Result<(), Error> { + fn validate_allowed_headers(&self, req: &mut HttpRequest) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { if let Ok(headers) = hdr.to_str() { match self.headers { @@ -253,20 +253,20 @@ impl Cors { for hdr in headers.split(',') { match HeaderName::try_from(hdr.trim()) { Ok(hdr) => hdrs.insert(hdr), - Err(_) => return Err(Error::BadRequestHeaders) + Err(_) => return Err(CorsError::BadRequestHeaders) }; } if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { - return Err(Error::HeadersNotAllowed) + return Err(CorsError::HeadersNotAllowed) } return Ok(()) } } } - Err(Error::BadRequestHeaders) + Err(CorsError::BadRequestHeaders) } else { - Err(Error::MissingRequestHeaders) + Err(CorsError::MissingRequestHeaders) } } } @@ -626,7 +626,7 @@ impl CorsBuilder { } /// Finishes building and returns the built `Cors` instance. - pub fn finish(&mut self) -> Result { + pub fn finish(&mut self) -> Result { if !self.methods { self.allowed_methods(vec![Method::GET, Method::HEAD, Method::POST, Method::OPTIONS, Method::PUT, @@ -634,13 +634,13 @@ impl CorsBuilder { } if let Some(e) = self.error.take() { - return Err(BuilderError::ParseError(e)) + return Err(CorsBuilderError::ParseError(e)) } let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { - return Err(BuilderError::CredentialsWithWildcardOrigin) + return Err(CorsBuilderError::CredentialsWithWildcardOrigin) } if let AllOrSome::Some(ref origins) = cors.origins { @@ -744,4 +744,17 @@ mod tests { cors.preflight = false; assert!(cors.start(&mut req).unwrap().is_done()); } + + #[test] + fn test_validate_origin() { + let cors = Cors::build() + .allowed_origin("http://www.example.com").finish().unwrap(); + + let mut req = TestRequest::with_header( + "Origin", "https://www.unknown.com") + .method(Method::GET) + .finish(); + + assert!(cors.start(&mut req).is_err()); + } }