1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-30 10:42:55 +01:00

Fix preflight CORS header compliance; refactor previous patch. (#717)

This commit is contained in:
David McGuire 2019-03-10 21:26:54 -07:00 committed by Douman
parent ceca96da28
commit cc7f6b5eef
2 changed files with 74 additions and 47 deletions

View File

@ -16,6 +16,8 @@
* Do not remove `Content-Length` on `Body::Empty` and insert zero value if it is missing for `POST` and `PUT` methods. * Do not remove `Content-Length` on `Body::Empty` and insert zero value if it is missing for `POST` and `PUT` methods.
* Fix preflight CORS header compliance; refactor previous patch (#603). #717
## [0.7.18] - 2019-01-10 ## [0.7.18] - 2019-01-10
### Added ### Added

View File

@ -307,6 +307,32 @@ impl Cors {
} }
} }
fn access_control_allow_origin(&self, req: &Request) -> Option<HeaderValue> {
match self.inner.origins {
AllOrSome::All => {
if self.inner.send_wildcard {
Some(HeaderValue::from_static("*"))
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
Some(origin.clone())
} else {
None
}
}
AllOrSome::Some(ref origins) => {
if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| {
match o.to_str() {
Ok(os) => origins.contains(os),
_ => false
}
}) {
Some(origin.clone())
} else {
Some(self.inner.origins_str.as_ref().unwrap().clone())
}
}
}
}
fn validate_allowed_method(&self, req: &Request) -> Result<(), CorsError> { fn validate_allowed_method(&self, req: &Request) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if let Ok(meth) = hdr.to_str() { if let Ok(meth) = hdr.to_str() {
@ -390,21 +416,9 @@ impl<S> Middleware<S> for Cors {
}).if_some(headers, |headers, resp| { }).if_some(headers, |headers, resp| {
let _ = let _ =
resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
}).if_true(self.inner.origins.is_all(), |resp| { }).if_some(self.access_control_allow_origin(&req), |origin, resp| {
if self.inner.send_wildcard { let _ =
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
} else {
let origin = req.headers().get(header::ORIGIN).unwrap();
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
}
}).if_true(self.inner.origins.is_some(), |resp| {
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone(),
);
}).if_true(self.inner.supports_credentials, |resp| { }).if_true(self.inner.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}).header( }).header(
@ -430,37 +444,11 @@ impl<S> Middleware<S> for Cors {
fn response( fn response(
&self, req: &HttpRequest<S>, mut resp: HttpResponse, &self, req: &HttpRequest<S>, mut resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
match self.inner.origins {
AllOrSome::All => { if let Some(origin) = self.access_control_allow_origin(req) {
if self.inner.send_wildcard {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("*"),
);
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
resp.headers_mut() resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
}
}
AllOrSome::Some(ref origins) => {
if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| {
match o.to_str() {
Ok(os) => origins.contains(os),
_ => false
}
}) {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
} else {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone()
);
}; };
}
}
if let Some(ref expose) = self.inner.expose_hdrs { if let Some(ref expose) = self.inner.expose_hdrs {
resp.headers_mut().insert( resp.headers_mut().insert(
@ -1201,7 +1189,6 @@ mod tests {
let resp: HttpResponse = HttpResponse::Ok().into(); let resp: HttpResponse = HttpResponse::Ok().into();
let resp = cors.response(&req, resp).unwrap().response(); let resp = cors.response(&req, resp).unwrap().response();
print!("{:?}", resp);
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@ -1224,4 +1211,42 @@ mod tests {
.as_bytes() .as_bytes()
); );
} }
#[test]
fn test_multiple_origins_preflight() {
let cors = Cors::build()
.allowed_origin("https://example.com")
.allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET])
.finish();
let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS)
.finish();
let resp = cors.start(&req).ok().unwrap().response();
assert_eq!(
&b"https://example.com"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
let req = TestRequest::with_header("Origin", "https://example.org")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS)
.finish();
let resp = cors.start(&req).ok().unwrap().response();
assert_eq!(
&b"https://example.org"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
}
} }