From cc7f6b5eef4d9e2f67265f421e8d1e39fdf70168 Mon Sep 17 00:00:00 2001 From: David McGuire Date: Sun, 10 Mar 2019 21:26:54 -0700 Subject: [PATCH] Fix preflight CORS header compliance; refactor previous patch. (#717) --- CHANGES.md | 2 + src/middleware/cors.rs | 119 +++++++++++++++++++++++++---------------- 2 files changed, 74 insertions(+), 47 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 8cce7159..76f3465e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -16,6 +16,8 @@ * Do not remove `Content-Length` on `Body::Empty` and insert zero value if it is missing for `POST` and `PUT` methods. +* Fix preflight CORS header compliance; refactor previous patch (#603). #717 + ## [0.7.18] - 2019-01-10 ### Added diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 386d0007..80ee5b19 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -307,6 +307,32 @@ impl Cors { } } + fn access_control_allow_origin(&self, req: &Request) -> Option { + match self.inner.origins { + AllOrSome::All => { + if self.inner.send_wildcard { + Some(HeaderValue::from_static("*")) + } else if let Some(origin) = req.headers().get(header::ORIGIN) { + Some(origin.clone()) + } else { + None + } + } + AllOrSome::Some(ref origins) => { + if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| { + match o.to_str() { + Ok(os) => origins.contains(os), + _ => false + } + }) { + Some(origin.clone()) + } else { + Some(self.inner.origins_str.as_ref().unwrap().clone()) + } + } + } + } + fn validate_allowed_method(&self, req: &Request) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Ok(meth) = hdr.to_str() { @@ -390,21 +416,9 @@ impl Middleware for Cors { }).if_some(headers, |headers, resp| { let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }).if_true(self.inner.origins.is_all(), |resp| { - if self.inner.send_wildcard { - resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); - } else { - let origin = req.headers().get(header::ORIGIN).unwrap(); - resp.header( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - origin.clone(), - ); - } - }).if_true(self.inner.origins.is_some(), |resp| { - resp.header( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.inner.origins_str.as_ref().unwrap().clone(), - ); + }).if_some(self.access_control_allow_origin(&req), |origin, resp| { + let _ = + resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); }).if_true(self.inner.supports_credentials, |resp| { resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); }).header( @@ -430,37 +444,11 @@ impl Middleware for Cors { fn response( &self, req: &HttpRequest, mut resp: HttpResponse, ) -> Result { - match self.inner.origins { - AllOrSome::All => { - if self.inner.send_wildcard { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - HeaderValue::from_static("*"), - ); - } else if let Some(origin) = req.headers().get(header::ORIGIN) { - resp.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - } - } - AllOrSome::Some(ref origins) => { - if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| { - match o.to_str() { - Ok(os) => origins.contains(os), - _ => false - } - }) { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - origin.clone(), - ); - } else { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.inner.origins_str.as_ref().unwrap().clone() - ); - }; - } - } + + if let Some(origin) = self.access_control_allow_origin(req) { + resp.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); + }; if let Some(ref expose) = self.inner.expose_hdrs { resp.headers_mut().insert( @@ -1201,7 +1189,6 @@ mod tests { let resp: HttpResponse = HttpResponse::Ok().into(); let resp = cors.response(&req, resp).unwrap().response(); - print!("{:?}", resp); assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1224,4 +1211,42 @@ mod tests { .as_bytes() ); } + + #[test] + fn test_multiple_origins_preflight() { + let cors = Cors::build() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish(); + + + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .finish(); + + let resp = cors.start(&req).ok().unwrap().response(); + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .finish(); + + let resp = cors.start(&req).ok().unwrap().response(); + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } }