diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index f251d6b4b..b79708453 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -224,7 +224,10 @@ impl Cors { } Err(CorsError::BadOrigin) } else { - Ok(()) + return match self.origins { + AllOrSome::All => Ok(()), + _ => Err(CorsError::MissingOrigin) + } } } @@ -677,6 +680,14 @@ mod tests { } } } + impl Response { + fn response(self) -> HttpResponse { + match self { + Response::Done(resp) => resp, + _ => panic!(), + } + } + } #[test] #[should_panic(expected = "CredentialsWithWildcardOrigin")] @@ -746,15 +757,31 @@ mod tests { } #[test] - fn test_validate_origin() { + #[should_panic(expected = "MissingOrigin")] + fn test_validate_missing_origin() { + let cors = Cors::build() + .allowed_origin("https://www.example.com").finish().unwrap(); + + let mut req = HttpRequest::default(); + cors.start(&mut req).unwrap(); + } + + #[test] + #[should_panic(expected = "OriginNotAllowed")] + fn test_validate_not_allowed_origin() { let cors = Cors::build() .allowed_origin("https://www.example.com").finish().unwrap(); let mut req = TestRequest::with_header("Origin", "https://www.unknown.com") .method(Method::GET) .finish(); + cors.start(&mut req).unwrap(); + } - assert!(cors.start(&mut req).is_err()); + #[test] + fn test_validate_origin() { + let cors = Cors::build() + .allowed_origin("https://www.example.com").finish().unwrap(); let mut req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::GET) @@ -762,4 +789,46 @@ mod tests { assert!(cors.start(&mut req).unwrap().is_done()); } + + #[test] + fn test_response() { + let cors = Cors::build() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish().unwrap(); + + let mut req = TestRequest::with_header( + "Origin", "https://www.example.com") + .method(Method::OPTIONS) + .finish(); + + let resp: HttpResponse = HTTPOk.into(); + let resp = cors.response(&mut req, resp).unwrap().response(); + assert_eq!( + &b"*"[..], + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes()); + + let resp: HttpResponse = HTTPOk.build() + .header(header::VARY, "Accept") + .finish().unwrap(); + let resp = cors.response(&mut req, resp).unwrap().response(); + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes()); + + let cors = Cors::build() + .allowed_origin("https://www.example.com") + .finish().unwrap(); + let resp: HttpResponse = HTTPOk.into(); + let resp = cors.response(&mut req, resp).unwrap().response(); + assert_eq!( + &b"https://www.example.com/"[..], + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); + } }