mirror of
https://github.com/actix/actix-extras.git
synced 2025-01-23 15:24:36 +01:00
complete cors implementation
This commit is contained in:
parent
3f3dcf413b
commit
615db0d9d8
@ -396,7 +396,7 @@ impl HttpResponseBuilder {
|
||||
|
||||
/// This method calls provided closure with builder reference if value is true.
|
||||
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
|
||||
where F: Fn(&mut HttpResponseBuilder) + 'static
|
||||
where F: FnOnce(&mut HttpResponseBuilder)
|
||||
{
|
||||
if value {
|
||||
f(self);
|
||||
@ -406,7 +406,7 @@ impl HttpResponseBuilder {
|
||||
|
||||
/// This method calls provided closure with builder reference if value is Some.
|
||||
pub fn if_some<T, F>(&mut self, value: Option<&T>, f: F) -> &mut Self
|
||||
where F: Fn(&T, &mut HttpResponseBuilder) + 'static
|
||||
where F: FnOnce(&T, &mut HttpResponseBuilder)
|
||||
{
|
||||
if let Some(val) = value {
|
||||
f(val, self);
|
||||
|
@ -45,15 +45,16 @@
|
||||
//!
|
||||
//! Cors middleware automatically handle *OPTIONS* preflight request.
|
||||
use std::collections::HashSet;
|
||||
use std::iter::FromIterator;
|
||||
|
||||
use http::{self, Method, HttpTryFrom, Uri};
|
||||
use http::header::{self, HeaderName};
|
||||
use http::header::{self, HeaderName, HeaderValue};
|
||||
|
||||
use error::{Result, ResponseError};
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Middleware, Response, Started};
|
||||
use httpcodes::{HTTPOk, HTTPBadRequest};
|
||||
use middleware::{Middleware, Response, Started};
|
||||
|
||||
/// A set of errors that can occur during processing CORS
|
||||
#[derive(Debug, Fail)]
|
||||
@ -86,6 +87,13 @@ pub enum Error {
|
||||
/// One or more headers requested are not allowed
|
||||
#[fail(display="One or more headers requested are not allowed")]
|
||||
HeadersNotAllowed,
|
||||
}
|
||||
|
||||
/// A set of errors that can occur during building CORS middleware
|
||||
#[derive(Debug, Fail)]
|
||||
pub enum BuilderError {
|
||||
#[fail(display="Parse error: {}", _0)]
|
||||
ParseError(http::Error),
|
||||
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
|
||||
///
|
||||
/// This is a misconfiguration. Check the docuemntation for `Cors`.
|
||||
@ -93,6 +101,7 @@ pub enum Error {
|
||||
CredentialsWithWildcardOrigin,
|
||||
}
|
||||
|
||||
|
||||
impl ResponseError for Error {
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
@ -147,9 +156,34 @@ impl<T> AllOrSome<T> {
|
||||
pub struct Cors {
|
||||
methods: HashSet<Method>,
|
||||
origins: AllOrSome<HashSet<Uri>>,
|
||||
origins_str: Option<HeaderValue>,
|
||||
headers: AllOrSome<HashSet<HeaderName>>,
|
||||
expose_hdrs: Option<String>,
|
||||
max_age: Option<usize>,
|
||||
send_wildcards: bool,
|
||||
preflight: bool,
|
||||
send_wildcard: bool,
|
||||
supports_credentials: bool,
|
||||
vary_header: bool,
|
||||
}
|
||||
|
||||
impl Default for Cors {
|
||||
fn default() -> Cors {
|
||||
Cors {
|
||||
origins: AllOrSome::All,
|
||||
origins_str: None,
|
||||
methods: HashSet::from_iter(
|
||||
vec![Method::GET, Method::HEAD,
|
||||
Method::POST, Method::OPTIONS, Method::PUT,
|
||||
Method::PATCH, Method::DELETE].into_iter()),
|
||||
headers: AllOrSome::All,
|
||||
expose_hdrs: None,
|
||||
max_age: None,
|
||||
preflight: true,
|
||||
send_wildcard: false,
|
||||
supports_credentials: false,
|
||||
vary_header: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Cors {
|
||||
@ -157,13 +191,19 @@ impl Cors {
|
||||
CorsBuilder {
|
||||
cors: Some(Cors {
|
||||
origins: AllOrSome::All,
|
||||
origins_str: None,
|
||||
methods: HashSet::new(),
|
||||
headers: AllOrSome::All,
|
||||
expose_hdrs: None,
|
||||
max_age: None,
|
||||
send_wildcards: false,
|
||||
preflight: true,
|
||||
send_wildcard: false,
|
||||
supports_credentials: false,
|
||||
vary_header: true,
|
||||
}),
|
||||
methods: false,
|
||||
error: None,
|
||||
expose_hdrs: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,31 +274,90 @@ impl Cors {
|
||||
impl<S> Middleware<S> for Cors {
|
||||
|
||||
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||
if Method::OPTIONS == *req.method() {
|
||||
if self.preflight && Method::OPTIONS == *req.method() {
|
||||
self.validate_origin(req)?;
|
||||
self.validate_allowed_method(req)?;
|
||||
self.validate_allowed_headers(req)?;
|
||||
|
||||
Ok(Started::Response(
|
||||
HTTPOk.build()
|
||||
.if_some(self.max_age.as_ref(), |max_age, res| {
|
||||
let _ = res.header(
|
||||
.if_some(self.max_age.as_ref(), |max_age, resp| {
|
||||
let _ = resp.header(
|
||||
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
|
||||
.if_some(self.headers.as_ref(), |headers, res| {
|
||||
let _ = res.header(
|
||||
.if_some(self.headers.as_ref(), |headers, resp| {
|
||||
let _ = resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
headers.iter().fold(String::new(), |s, v| s + v.as_str()).as_str());})
|
||||
&headers.iter().fold(
|
||||
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]);})
|
||||
.if_true(self.origins.is_all(), |resp| {
|
||||
if self.send_wildcard {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
|
||||
} else {
|
||||
let origin = req.headers().get(header::ORIGIN).unwrap();
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
}
|
||||
})
|
||||
.if_true(self.origins.is_some(), |resp| {
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.origins_str.as_ref().unwrap().clone());
|
||||
})
|
||||
.if_true(self.supports_credentials, |resp| {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
|
||||
})
|
||||
.header(
|
||||
header::ACCESS_CONTROL_ALLOW_METHODS,
|
||||
self.methods.iter().fold(String::new(), |s, v| s + v.as_str()).as_str())
|
||||
&self.methods.iter().fold(
|
||||
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
|
||||
.finish()
|
||||
.unwrap()))
|
||||
} else {
|
||||
self.validate_origin(req)?;
|
||||
|
||||
Ok(Started::Done)
|
||||
}
|
||||
}
|
||||
|
||||
fn response(&self, _: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
|
||||
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
|
||||
match self.origins {
|
||||
AllOrSome::All => {
|
||||
if self.send_wildcard {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
|
||||
} else {
|
||||
let origin = req.headers().get(header::ORIGIN).unwrap();
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.origins_str.as_ref().unwrap().clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref expose) = self.expose_hdrs {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
HeaderValue::try_from(expose.as_str()).unwrap());
|
||||
}
|
||||
if self.supports_credentials {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
|
||||
}
|
||||
if self.vary_header {
|
||||
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
|
||||
let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
|
||||
val.extend(hdr.as_bytes());
|
||||
val.extend(b", Origin");
|
||||
HeaderValue::try_from(&val[..]).unwrap()
|
||||
} else {
|
||||
HeaderValue::from_static("Origin")
|
||||
};
|
||||
resp.headers_mut().insert(header::VARY, value);
|
||||
}
|
||||
Ok(Response::Done(resp))
|
||||
}
|
||||
}
|
||||
@ -293,6 +392,7 @@ pub struct CorsBuilder {
|
||||
cors: Option<Cors>,
|
||||
methods: bool,
|
||||
error: Option<http::Error>,
|
||||
expose_hdrs: HashSet<HeaderName>,
|
||||
}
|
||||
|
||||
fn cors<'a>(parts: &'a mut Option<Cors>, err: &Option<http::Error>) -> Option<&'a mut Cors> {
|
||||
@ -421,6 +521,30 @@ impl CorsBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
|
||||
/// This corresponds to the `Access-Control-Expose-Headers` responde header.
|
||||
///
|
||||
/// This is the `list of exposed headers` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
///
|
||||
/// This defaults to an empty set.
|
||||
pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder
|
||||
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
|
||||
{
|
||||
for h in headers {
|
||||
match HeaderName::try_from(h) {
|
||||
Ok(method) => {
|
||||
self.expose_hdrs.insert(method);
|
||||
},
|
||||
Err(e) => {
|
||||
self.error = Some(e.into());
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a maximum time for which this CORS request maybe cached.
|
||||
/// This value is set as the `Access-Control-Max-Age` header.
|
||||
///
|
||||
@ -449,13 +573,60 @@ impl CorsBuilder {
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub fn send_wildcard(&mut self) -> &mut CorsBuilder {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
cors.send_wildcards = true
|
||||
cors.send_wildcard = true
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
/// Allows users to make authenticated requests
|
||||
///
|
||||
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
|
||||
/// This allows cookies and credentials to be submitted across domains.
|
||||
///
|
||||
/// This option cannot be used in conjuction with an `allowed_origin` set to `All`
|
||||
/// and `send_wildcards` set to `true`.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
pub fn supports_credentials(&mut self) -> &mut CorsBuilder {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
cors.supports_credentials = true
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable `Vary` header support.
|
||||
///
|
||||
/// When enabled the header `Vary: Origin` will be returned as per the W3
|
||||
/// implementation guidelines.
|
||||
///
|
||||
/// Setting this header when the `Access-Control-Allow-Origin` is
|
||||
/// dynamically generated (e.g. when there is more than one allowed
|
||||
/// origin, and an Origin than '*' is returned) informs CDNs and other
|
||||
/// caches that the CORS headers are dynamic, and cannot be cached.
|
||||
///
|
||||
/// By default `vary` header support is enabled.
|
||||
pub fn disable_vary_header(&mut self) -> &mut CorsBuilder {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
cors.vary_header = false
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable *preflight* request support.
|
||||
///
|
||||
/// When enabled cors middleware automatically handles *OPTIONS* request.
|
||||
/// This is useful application level middleware.
|
||||
///
|
||||
/// By default *preflight* support is enabled.
|
||||
pub fn disable_preflight(&mut self) -> &mut CorsBuilder {
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
cors.preflight = false
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Finishes building and returns the built `Cors` instance.
|
||||
pub fn finish(&mut self) -> Result<Cors, http::Error> {
|
||||
pub fn finish(&mut self) -> Result<Cors, BuilderError> {
|
||||
if !self.methods {
|
||||
self.allowed_methods(vec![Method::GET, Method::HEAD,
|
||||
Method::POST, Method::OPTIONS, Method::PUT,
|
||||
@ -463,9 +634,114 @@ impl CorsBuilder {
|
||||
}
|
||||
|
||||
if let Some(e) = self.error.take() {
|
||||
return Err(e)
|
||||
return Err(BuilderError::ParseError(e))
|
||||
}
|
||||
|
||||
Ok(self.cors.take().expect("cannot reuse CorsBuilder"))
|
||||
let mut cors = self.cors.take().expect("cannot reuse CorsBuilder");
|
||||
|
||||
if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
|
||||
return Err(BuilderError::CredentialsWithWildcardOrigin)
|
||||
}
|
||||
|
||||
if let AllOrSome::Some(ref origins) = cors.origins {
|
||||
let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v));
|
||||
cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap());
|
||||
}
|
||||
|
||||
if !self.expose_hdrs.is_empty() {
|
||||
cors.expose_hdrs = Some(
|
||||
self.expose_hdrs.iter().fold(
|
||||
String::new(), |s, v| s + v.as_str())[1..].to_owned());
|
||||
}
|
||||
Ok(cors)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use test::TestRequest;
|
||||
|
||||
impl Started {
|
||||
fn is_done(&self) -> bool {
|
||||
match *self {
|
||||
Started::Done => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn response(self) -> HttpResponse {
|
||||
match self {
|
||||
Started::Response(resp) => resp,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||
fn cors_validates_illegal_allow_credentials() {
|
||||
Cors::build()
|
||||
.supports_credentials()
|
||||
.send_wildcard()
|
||||
.finish()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_origin_allows_all_origins() {
|
||||
let cors = Cors::default();
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com").finish();
|
||||
|
||||
assert!(cors.start(&mut req).ok().unwrap().is_done())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preflight() {
|
||||
let mut cors = Cors::build()
|
||||
.send_wildcard()
|
||||
.max_age(3600)
|
||||
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
|
||||
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
|
||||
.allowed_header(header::CONTENT_TYPE)
|
||||
.finish().unwrap();
|
||||
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
assert!(cors.start(&mut req).is_err());
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
assert!(cors.start(&mut req).is_err());
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
let resp = cors.start(&mut req).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"*"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
|
||||
assert_eq!(
|
||||
&b"3600"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes());
|
||||
//assert_eq!(
|
||||
// &b"authorization,accept,content-type"[..],
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes());
|
||||
//assert_eq!(
|
||||
// &b"POST,GET,OPTIONS"[..],
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
|
||||
|
||||
cors.preflight = false;
|
||||
assert!(cors.start(&mut req).unwrap().is_done());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user