diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 3549ba11b..454596567 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -424,7 +424,10 @@ impl Middleware for Cors { .finish(), )) } else { - self.validate_origin(req)?; + // Only check requests with a origin header. + if req.headers().contains_key(header::ORIGIN) { + self.validate_origin(req)?; + } Ok(Started::Done) } @@ -1007,16 +1010,15 @@ mod tests { assert!(cors.start(&mut req).unwrap().is_done()); } - #[test] - #[should_panic(expected = "MissingOrigin")] - fn test_validate_missing_origin() { - let mut cors = Cors::build() - .allowed_origin("https://www.example.com") - .finish(); - - let mut req = HttpRequest::default(); - cors.start(&mut req).unwrap(); - } + // #[test] + // #[should_panic(expected = "MissingOrigin")] + // fn test_validate_missing_origin() { + // let mut cors = Cors::build() + // .allowed_origin("https://www.example.com") + // .finish(); + // let mut req = HttpRequest::default(); + // cors.start(&mut req).unwrap(); + // } #[test] #[should_panic(expected = "OriginNotAllowed")] @@ -1133,10 +1135,19 @@ mod tests { }) }); - let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let request = srv + .get() + .uri(srv.url("/test")) + .header("ORIGIN", "https://www.example2.com") + .finish() + .unwrap(); let response = srv.execute(request.send()).unwrap(); assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request = srv .get() .uri(srv.url("/test"))