diff --git a/CHANGES.md b/CHANGES.md index bdc50fc51..89c85b982 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,7 @@ ### Fixed * HTTP1 decoder should perform case-insentive comparison for client requests (e.g. `Keep-Alive`). #631 +* Access-Control-Allow-Origin header should only a return a single, matching origin. #603 ## [0.7.16] - 2018-12-11 diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 953f2911c..386d00078 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -442,11 +442,23 @@ impl Middleware for Cors { .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); } } - AllOrSome::Some(_) => { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.inner.origins_str.as_ref().unwrap().clone(), - ); + AllOrSome::Some(ref origins) => { + if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| { + match o.to_str() { + Ok(os) => origins.contains(os), + _ => false + } + }) { + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + } else { + resp.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + self.inner.origins_str.as_ref().unwrap().clone() + ); + }; } } @@ -1134,17 +1146,10 @@ mod tests { .to_str() .unwrap(); - if origins_str.starts_with("https://www.example.com") { - assert_eq!( - "https://www.example.com, https://www.google.com", - origins_str - ); - } else { - assert_eq!( - "https://www.google.com, https://www.example.com", - origins_str - ); - } + assert_eq!( + "https://www.example.com", + origins_str + ); } #[test] @@ -1180,4 +1185,43 @@ mod tests { let response = srv.execute(request.send()).unwrap(); assert_eq!(response.status(), StatusCode::OK); } + + #[test] + fn test_multiple_origins() { + let cors = Cors::build() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish(); + + + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .finish(); + let resp: HttpResponse = HttpResponse::Ok().into(); + + let resp = cors.response(&req, resp).unwrap().response(); + print!("{:?}", resp); + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .finish(); + let resp: HttpResponse = HttpResponse::Ok().into(); + + let resp = cors.response(&req, resp).unwrap().response(); + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } }