mirror of
https://github.com/fafhrd91/actix-web
synced 2024-11-23 16:21:06 +01:00
Fix preflight CORS header compliance; refactor previous patch. (#717)
This commit is contained in:
parent
ceca96da28
commit
cc7f6b5eef
@ -16,6 +16,8 @@
|
||||
|
||||
* Do not remove `Content-Length` on `Body::Empty` and insert zero value if it is missing for `POST` and `PUT` methods.
|
||||
|
||||
* Fix preflight CORS header compliance; refactor previous patch (#603). #717
|
||||
|
||||
## [0.7.18] - 2019-01-10
|
||||
|
||||
### Added
|
||||
|
@ -307,6 +307,32 @@ impl Cors {
|
||||
}
|
||||
}
|
||||
|
||||
fn access_control_allow_origin(&self, req: &Request) -> Option<HeaderValue> {
|
||||
match self.inner.origins {
|
||||
AllOrSome::All => {
|
||||
if self.inner.send_wildcard {
|
||||
Some(HeaderValue::from_static("*"))
|
||||
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
|
||||
Some(origin.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(ref origins) => {
|
||||
if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| {
|
||||
match o.to_str() {
|
||||
Ok(os) => origins.contains(os),
|
||||
_ => false
|
||||
}
|
||||
}) {
|
||||
Some(origin.clone())
|
||||
} else {
|
||||
Some(self.inner.origins_str.as_ref().unwrap().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_allowed_method(&self, req: &Request) -> Result<(), CorsError> {
|
||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
|
||||
if let Ok(meth) = hdr.to_str() {
|
||||
@ -390,21 +416,9 @@ impl<S> Middleware<S> for Cors {
|
||||
}).if_some(headers, |headers, resp| {
|
||||
let _ =
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
|
||||
}).if_true(self.inner.origins.is_all(), |resp| {
|
||||
if self.inner.send_wildcard {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
|
||||
} else {
|
||||
let origin = req.headers().get(header::ORIGIN).unwrap();
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
origin.clone(),
|
||||
);
|
||||
}
|
||||
}).if_true(self.inner.origins.is_some(), |resp| {
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.inner.origins_str.as_ref().unwrap().clone(),
|
||||
);
|
||||
}).if_some(self.access_control_allow_origin(&req), |origin, resp| {
|
||||
let _ =
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||
}).if_true(self.inner.supports_credentials, |resp| {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
|
||||
}).header(
|
||||
@ -430,37 +444,11 @@ impl<S> Middleware<S> for Cors {
|
||||
fn response(
|
||||
&self, req: &HttpRequest<S>, mut resp: HttpResponse,
|
||||
) -> Result<Response> {
|
||||
match self.inner.origins {
|
||||
AllOrSome::All => {
|
||||
if self.inner.send_wildcard {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
|
||||
|
||||
if let Some(origin) = self.access_control_allow_origin(req) {
|
||||
resp.headers_mut()
|
||||
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(ref origins) => {
|
||||
if let Some(origin) = req.headers().get(header::ORIGIN).filter(|o| {
|
||||
match o.to_str() {
|
||||
Ok(os) => origins.contains(os),
|
||||
_ => false
|
||||
}
|
||||
}) {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
origin.clone(),
|
||||
);
|
||||
} else {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.inner.origins_str.as_ref().unwrap().clone()
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref expose) = self.inner.expose_hdrs {
|
||||
resp.headers_mut().insert(
|
||||
@ -1201,7 +1189,6 @@ mod tests {
|
||||
let resp: HttpResponse = HttpResponse::Ok().into();
|
||||
|
||||
let resp = cors.response(&req, resp).unwrap().response();
|
||||
print!("{:?}", resp);
|
||||
assert_eq!(
|
||||
&b"https://example.com"[..],
|
||||
resp.headers()
|
||||
@ -1224,4 +1211,42 @@ mod tests {
|
||||
.as_bytes()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_origins_preflight() {
|
||||
let cors = Cors::build()
|
||||
.allowed_origin("https://example.com")
|
||||
.allowed_origin("https://example.org")
|
||||
.allowed_methods(vec![Method::GET])
|
||||
.finish();
|
||||
|
||||
|
||||
let req = TestRequest::with_header("Origin", "https://example.com")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
let resp = cors.start(&req).ok().unwrap().response();
|
||||
assert_eq!(
|
||||
&b"https://example.com"[..],
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
|
||||
let req = TestRequest::with_header("Origin", "https://example.org")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
let resp = cors.start(&req).ok().unwrap().response();
|
||||
assert_eq!(
|
||||
&b"https://example.org"[..],
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user