mirror of
https://github.com/actix/actix-extras.git
synced 2025-06-25 18:09:22 +02:00
add rustfmt config
This commit is contained in:
@ -7,7 +7,8 @@
|
||||
//!
|
||||
//! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
|
||||
//! 2. Use any of the builder methods to set fields in the backend.
|
||||
//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend.
|
||||
//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
|
||||
//! constructed backend.
|
||||
//!
|
||||
//! Cors middleware could be used as parameter for `App::middleware()` or
|
||||
//! `ResourceHandler::middleware()` methods. But you have to use
|
||||
@ -40,65 +41,69 @@
|
||||
//! .register());
|
||||
//! }
|
||||
//! ```
|
||||
//! In this example custom *CORS* middleware get registered for "/index.html" endpoint.
|
||||
//! In this example custom *CORS* middleware get registered for "/index.html"
|
||||
//! endpoint.
|
||||
//!
|
||||
//! Cors middleware automatically handle *OPTIONS* preflight request.
|
||||
use std::collections::HashSet;
|
||||
use std::iter::FromIterator;
|
||||
use std::rc::Rc;
|
||||
|
||||
use http::{self, Method, HttpTryFrom, Uri, StatusCode};
|
||||
use http::header::{self, HeaderName, HeaderValue};
|
||||
use http::{self, HttpTryFrom, Method, StatusCode, Uri};
|
||||
|
||||
use application::App;
|
||||
use error::{Result, ResponseError};
|
||||
use resource::ResourceHandler;
|
||||
use error::{ResponseError, Result};
|
||||
use httpmessage::HttpMessage;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Middleware, Response, Started};
|
||||
use resource::ResourceHandler;
|
||||
|
||||
/// A set of errors that can occur during processing CORS
|
||||
#[derive(Debug, Fail)]
|
||||
pub enum CorsError {
|
||||
/// The HTTP request header `Origin` is required but was not provided
|
||||
#[fail(display="The HTTP request header `Origin` is required but was not provided")]
|
||||
#[fail(display = "The HTTP request header `Origin` is required but was not provided")]
|
||||
MissingOrigin,
|
||||
/// The HTTP request header `Origin` could not be parsed correctly.
|
||||
#[fail(display="The HTTP request header `Origin` could not be parsed correctly.")]
|
||||
#[fail(display = "The HTTP request header `Origin` could not be parsed correctly.")]
|
||||
BadOrigin,
|
||||
/// The request header `Access-Control-Request-Method` is required but is missing
|
||||
#[fail(display="The request header `Access-Control-Request-Method` is required but is missing")]
|
||||
/// The request header `Access-Control-Request-Method` is required but is
|
||||
/// missing
|
||||
#[fail(display = "The request header `Access-Control-Request-Method` is required but is missing")]
|
||||
MissingRequestMethod,
|
||||
/// The request header `Access-Control-Request-Method` has an invalid value
|
||||
#[fail(display="The request header `Access-Control-Request-Method` has an invalid value")]
|
||||
#[fail(display = "The request header `Access-Control-Request-Method` has an invalid value")]
|
||||
BadRequestMethod,
|
||||
/// The request header `Access-Control-Request-Headers` has an invalid value
|
||||
#[fail(display="The request header `Access-Control-Request-Headers` has an invalid value")]
|
||||
/// The request header `Access-Control-Request-Headers` has an invalid
|
||||
/// value
|
||||
#[fail(display = "The request header `Access-Control-Request-Headers` has an invalid value")]
|
||||
BadRequestHeaders,
|
||||
/// The request header `Access-Control-Request-Headers` is required but is missing.
|
||||
#[fail(display="The request header `Access-Control-Request-Headers` is required but is
|
||||
/// The request header `Access-Control-Request-Headers` is required but is
|
||||
/// missing.
|
||||
#[fail(display = "The request header `Access-Control-Request-Headers` is required but is
|
||||
missing")]
|
||||
MissingRequestHeaders,
|
||||
/// Origin is not allowed to make this request
|
||||
#[fail(display="Origin is not allowed to make this request")]
|
||||
#[fail(display = "Origin is not allowed to make this request")]
|
||||
OriginNotAllowed,
|
||||
/// Requested method is not allowed
|
||||
#[fail(display="Requested method is not allowed")]
|
||||
#[fail(display = "Requested method is not allowed")]
|
||||
MethodNotAllowed,
|
||||
/// One or more headers requested are not allowed
|
||||
#[fail(display="One or more headers requested are not allowed")]
|
||||
#[fail(display = "One or more headers requested are not allowed")]
|
||||
HeadersNotAllowed,
|
||||
}
|
||||
|
||||
impl ResponseError for CorsError {
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self))
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||
/// An enum signifying that some of type T is allowed, or `All` (everything is
|
||||
/// allowed).
|
||||
///
|
||||
/// `Default` is implemented for this enum and is `All`.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
@ -166,9 +171,16 @@ impl Default for Cors {
|
||||
origins: AllOrSome::default(),
|
||||
origins_str: None,
|
||||
methods: HashSet::from_iter(
|
||||
vec![Method::GET, Method::HEAD,
|
||||
Method::POST, Method::OPTIONS, Method::PUT,
|
||||
Method::PATCH, Method::DELETE].into_iter()),
|
||||
vec![
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::POST,
|
||||
Method::OPTIONS,
|
||||
Method::PUT,
|
||||
Method::PATCH,
|
||||
Method::DELETE,
|
||||
].into_iter(),
|
||||
),
|
||||
headers: AllOrSome::All,
|
||||
expose_hdrs: None,
|
||||
max_age: None,
|
||||
@ -177,7 +189,9 @@ impl Default for Cors {
|
||||
supports_credentials: false,
|
||||
vary_header: true,
|
||||
};
|
||||
Cors{inner: Rc::new(inner)}
|
||||
Cors {
|
||||
inner: Rc::new(inner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,11 +261,13 @@ impl Cors {
|
||||
/// This method register cors middleware with resource and
|
||||
/// adds route for *OPTIONS* preflight requests.
|
||||
///
|
||||
/// It is possible to register *Cors* middleware with `ResourceHandler::middleware()`
|
||||
/// method, but in that case *Cors* middleware wont be able to handle *OPTIONS*
|
||||
/// requests.
|
||||
/// It is possible to register *Cors* middleware with
|
||||
/// `ResourceHandler::middleware()` method, but in that case *Cors*
|
||||
/// middleware wont be able to handle *OPTIONS* requests.
|
||||
pub fn register<S: 'static>(self, resource: &mut ResourceHandler<S>) {
|
||||
resource.method(Method::OPTIONS).h(|_| HttpResponse::Ok());
|
||||
resource
|
||||
.method(Method::OPTIONS)
|
||||
.h(|_| HttpResponse::Ok());
|
||||
resource.middleware(self);
|
||||
}
|
||||
|
||||
@ -260,28 +276,32 @@ impl Cors {
|
||||
if let Ok(origin) = hdr.to_str() {
|
||||
return match self.inner.origins {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
allowed_origins
|
||||
.get(origin)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| CorsError::OriginNotAllowed)
|
||||
}
|
||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||
.get(origin)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| CorsError::OriginNotAllowed),
|
||||
};
|
||||
}
|
||||
Err(CorsError::BadOrigin)
|
||||
} else {
|
||||
return match self.inner.origins {
|
||||
AllOrSome::All => Ok(()),
|
||||
_ => Err(CorsError::MissingOrigin)
|
||||
}
|
||||
_ => Err(CorsError::MissingOrigin),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
|
||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
|
||||
fn validate_allowed_method<S>(
|
||||
&self, req: &mut HttpRequest<S>
|
||||
) -> Result<(), CorsError> {
|
||||
if let Some(hdr) = req.headers()
|
||||
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
|
||||
{
|
||||
if let Ok(meth) = hdr.to_str() {
|
||||
if let Ok(method) = Method::try_from(meth) {
|
||||
return self.inner.methods.get(&method)
|
||||
return self.inner
|
||||
.methods
|
||||
.get(&method)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| CorsError::MethodNotAllowed);
|
||||
}
|
||||
@ -292,24 +312,28 @@ impl Cors {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
|
||||
fn validate_allowed_headers<S>(
|
||||
&self, req: &mut HttpRequest<S>
|
||||
) -> Result<(), CorsError> {
|
||||
match self.inner.headers {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_headers) => {
|
||||
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
if let Some(hdr) = req.headers()
|
||||
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
||||
{
|
||||
if let Ok(headers) = hdr.to_str() {
|
||||
let mut hdrs = HashSet::new();
|
||||
for hdr in headers.split(',') {
|
||||
match HeaderName::try_from(hdr.trim()) {
|
||||
Ok(hdr) => hdrs.insert(hdr),
|
||||
Err(_) => return Err(CorsError::BadRequestHeaders)
|
||||
Err(_) => return Err(CorsError::BadRequestHeaders),
|
||||
};
|
||||
}
|
||||
|
||||
if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) {
|
||||
return Err(CorsError::HeadersNotAllowed)
|
||||
return Err(CorsError::HeadersNotAllowed);
|
||||
}
|
||||
return Ok(())
|
||||
return Ok(());
|
||||
}
|
||||
Err(CorsError::BadRequestHeaders)
|
||||
} else {
|
||||
@ -321,7 +345,6 @@ impl Cors {
|
||||
}
|
||||
|
||||
impl<S> Middleware<S> for Cors {
|
||||
|
||||
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||
if self.inner.preflight && Method::OPTIONS == *req.method() {
|
||||
self.validate_origin(req)?;
|
||||
@ -330,9 +353,17 @@ impl<S> Middleware<S> for Cors {
|
||||
|
||||
// allowed headers
|
||||
let headers = if let Some(headers) = self.inner.headers.as_ref() {
|
||||
Some(HeaderValue::try_from(&headers.iter().fold(
|
||||
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap())
|
||||
} else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
Some(
|
||||
HeaderValue::try_from(
|
||||
&headers
|
||||
.iter()
|
||||
.fold(String::new(), |s, v| s + "," + v.as_str())
|
||||
.as_str()[1..],
|
||||
).unwrap(),
|
||||
)
|
||||
} else if let Some(hdr) = req.headers()
|
||||
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
|
||||
{
|
||||
Some(hdr.clone())
|
||||
} else {
|
||||
None
|
||||
@ -342,31 +373,44 @@ impl<S> Middleware<S> for Cors {
|
||||
HttpResponse::Ok()
|
||||
.if_some(self.inner.max_age.as_ref(), |max_age, resp| {
|
||||
let _ = resp.header(
|
||||
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
|
||||
header::ACCESS_CONTROL_MAX_AGE,
|
||||
format!("{}", max_age).as_str(),
|
||||
);
|
||||
})
|
||||
.if_some(headers, |headers, resp| {
|
||||
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); })
|
||||
let _ =
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
|
||||
})
|
||||
.if_true(self.inner.origins.is_all(), |resp| {
|
||||
if self.inner.send_wildcard {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
|
||||
} else {
|
||||
let origin = req.headers().get(header::ORIGIN).unwrap();
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
origin.clone(),
|
||||
);
|
||||
}
|
||||
})
|
||||
.if_true(self.inner.origins.is_some(), |resp| {
|
||||
resp.header(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.inner.origins_str.as_ref().unwrap().clone());
|
||||
self.inner.origins_str.as_ref().unwrap().clone(),
|
||||
);
|
||||
})
|
||||
.if_true(self.inner.supports_credentials, |resp| {
|
||||
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
|
||||
})
|
||||
.header(
|
||||
header::ACCESS_CONTROL_ALLOW_METHODS,
|
||||
&self.inner.methods.iter().fold(
|
||||
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
|
||||
.finish()))
|
||||
&self.inner
|
||||
.methods
|
||||
.iter()
|
||||
.fold(String::new(), |s, v| s + "," + v.as_str())
|
||||
.as_str()[1..],
|
||||
)
|
||||
.finish(),
|
||||
))
|
||||
} else {
|
||||
self.validate_origin(req)?;
|
||||
|
||||
@ -374,32 +418,40 @@ impl<S> Middleware<S> for Cors {
|
||||
}
|
||||
}
|
||||
|
||||
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
|
||||
fn response(
|
||||
&self, req: &mut HttpRequest<S>, mut resp: HttpResponse
|
||||
) -> Result<Response> {
|
||||
match self.inner.origins {
|
||||
AllOrSome::All => {
|
||||
if self.inner.send_wildcard {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
resp.headers_mut()
|
||||
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.inner.origins_str.as_ref().unwrap().clone());
|
||||
self.inner.origins_str.as_ref().unwrap().clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref expose) = self.inner.expose_hdrs {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
HeaderValue::try_from(expose.as_str()).unwrap());
|
||||
HeaderValue::try_from(expose.as_str()).unwrap(),
|
||||
);
|
||||
}
|
||||
if self.inner.supports_credentials {
|
||||
resp.headers_mut().insert(
|
||||
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
|
||||
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
||||
HeaderValue::from_static("true"),
|
||||
);
|
||||
}
|
||||
if self.inner.vary_header {
|
||||
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
|
||||
@ -416,13 +468,15 @@ impl<S> Middleware<S> for Cors {
|
||||
}
|
||||
}
|
||||
|
||||
/// Structure that follows the builder pattern for building `Cors` middleware structs.
|
||||
/// Structure that follows the builder pattern for building `Cors` middleware
|
||||
/// structs.
|
||||
///
|
||||
/// To construct a cors:
|
||||
///
|
||||
/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
|
||||
/// 2. Use any of the builder methods to set fields in the backend.
|
||||
/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend.
|
||||
/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
|
||||
/// constructed backend.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -442,7 +496,7 @@ impl<S> Middleware<S> for Cors {
|
||||
/// .finish();
|
||||
/// # }
|
||||
/// ```
|
||||
pub struct CorsBuilder<S=()> {
|
||||
pub struct CorsBuilder<S = ()> {
|
||||
cors: Option<Inner>,
|
||||
methods: bool,
|
||||
error: Option<http::Error>,
|
||||
@ -451,26 +505,26 @@ pub struct CorsBuilder<S=()> {
|
||||
app: Option<App<S>>,
|
||||
}
|
||||
|
||||
fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<http::Error>)
|
||||
-> Option<&'a mut Inner>
|
||||
{
|
||||
fn cors<'a>(
|
||||
parts: &'a mut Option<Inner>, err: &Option<http::Error>
|
||||
) -> Option<&'a mut Inner> {
|
||||
if err.is_some() {
|
||||
return None
|
||||
return None;
|
||||
}
|
||||
parts.as_mut()
|
||||
}
|
||||
|
||||
impl<S: 'static> CorsBuilder<S> {
|
||||
|
||||
/// Add an origin that are allowed to make requests.
|
||||
/// Will be verified against the `Origin` request header.
|
||||
///
|
||||
/// When `All` is set, and `send_wildcard` is set, "*" will be sent in
|
||||
/// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request
|
||||
/// header will be echoed back in the `Access-Control-Allow-Origin` response header.
|
||||
/// the `Access-Control-Allow-Origin` response header. Otherwise, the
|
||||
/// client's `Origin` request header will be echoed back in the
|
||||
/// `Access-Control-Allow-Origin` response header.
|
||||
///
|
||||
/// When `Some` is set, the client's `Origin` request header will be checked in a
|
||||
/// case-sensitive manner.
|
||||
/// When `Some` is set, the client's `Origin` request header will be
|
||||
/// checked in a case-sensitive manner.
|
||||
///
|
||||
/// This is the `list of origins` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
@ -497,15 +551,17 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a list of methods which the allowed origins are allowed to access for
|
||||
/// requests.
|
||||
/// Set a list of methods which the allowed origins are allowed to access
|
||||
/// for requests.
|
||||
///
|
||||
/// This is the `list of methods` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
///
|
||||
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
|
||||
pub fn allowed_methods<U, M>(&mut self, methods: U) -> &mut CorsBuilder<S>
|
||||
where U: IntoIterator<Item=M>, Method: HttpTryFrom<M>
|
||||
where
|
||||
U: IntoIterator<Item = M>,
|
||||
Method: HttpTryFrom<M>,
|
||||
{
|
||||
self.methods = true;
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
@ -513,20 +569,21 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
match Method::try_from(m) {
|
||||
Ok(method) => {
|
||||
cors.methods.insert(method);
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
self.error = Some(e.into());
|
||||
break
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an allowed header
|
||||
pub fn allowed_header<H>(&mut self, header: H) -> &mut CorsBuilder<S>
|
||||
where HeaderName: HttpTryFrom<H>
|
||||
where
|
||||
HeaderName: HttpTryFrom<H>,
|
||||
{
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
match HeaderName::try_from(header) {
|
||||
@ -547,15 +604,18 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
/// Set a list of header field names which can be used when
|
||||
/// this resource is accessed by allowed origins.
|
||||
///
|
||||
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
|
||||
/// will be echoed back in the `Access-Control-Allow-Headers` header.
|
||||
/// If `All` is set, whatever is requested by the client in
|
||||
/// `Access-Control-Request-Headers` will be echoed back in the
|
||||
/// `Access-Control-Allow-Headers` header.
|
||||
///
|
||||
/// This is the `list of headers` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
///
|
||||
/// Defaults to `All`.
|
||||
pub fn allowed_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
|
||||
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
|
||||
where
|
||||
U: IntoIterator<Item = H>,
|
||||
HeaderName: HttpTryFrom<H>,
|
||||
{
|
||||
if let Some(cors) = cors(&mut self.cors, &self.error) {
|
||||
for h in headers {
|
||||
@ -570,32 +630,35 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
}
|
||||
Err(e) => {
|
||||
self.error = Some(e.into());
|
||||
break
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
|
||||
/// This corresponds to the `Access-Control-Expose-Headers` response header.
|
||||
/// Set a list of headers which are safe to expose to the API of a CORS API
|
||||
/// specification. This corresponds to the
|
||||
/// `Access-Control-Expose-Headers` response header.
|
||||
///
|
||||
/// This is the `list of exposed headers` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
///
|
||||
/// This defaults to an empty set.
|
||||
pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
|
||||
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
|
||||
where
|
||||
U: IntoIterator<Item = H>,
|
||||
HeaderName: HttpTryFrom<H>,
|
||||
{
|
||||
for h in headers {
|
||||
match HeaderName::try_from(h) {
|
||||
Ok(method) => {
|
||||
self.expose_hdrs.insert(method);
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
self.error = Some(e.into());
|
||||
break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -615,16 +678,17 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
|
||||
/// Set a wildcard origins
|
||||
///
|
||||
/// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
|
||||
/// `Access-Control-Allow-Origin` response header is sent, rather than the request’s
|
||||
/// `Origin` header.
|
||||
/// If send wildcard is set and the `allowed_origins` parameter is `All`, a
|
||||
/// wildcard `Access-Control-Allow-Origin` response header is sent,
|
||||
/// rather than the request’s `Origin` header.
|
||||
///
|
||||
/// This is the `supports credentials flag` in the
|
||||
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
|
||||
///
|
||||
/// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and
|
||||
/// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result
|
||||
/// in an `Error::CredentialsWithWildcardOrigin` error during actix launch or runtime.
|
||||
/// This **CANNOT** be used in conjunction with `allowed_origins` set to
|
||||
/// `All` and `allow_credentials` set to `true`. Depending on the mode
|
||||
/// of usage, this will either result in an `Error::
|
||||
/// CredentialsWithWildcardOrigin` error during actix launch or runtime.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
pub fn send_wildcard(&mut self) -> &mut CorsBuilder<S> {
|
||||
@ -636,11 +700,12 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
|
||||
/// Allows users to make authenticated requests
|
||||
///
|
||||
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
|
||||
/// This allows cookies and credentials to be submitted across domains.
|
||||
/// If true, injects the `Access-Control-Allow-Credentials` header in
|
||||
/// responses. This allows cookies and credentials to be submitted
|
||||
/// across domains.
|
||||
///
|
||||
/// This option cannot be used in conjunction with an `allowed_origin` set to `All`
|
||||
/// and `send_wildcards` set to `true`.
|
||||
/// This option cannot be used in conjunction with an `allowed_origin` set
|
||||
/// to `All` and `send_wildcards` set to `true`.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
///
|
||||
@ -713,7 +778,8 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
/// }
|
||||
/// ```
|
||||
pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut CorsBuilder<S>
|
||||
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static
|
||||
where
|
||||
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
|
||||
{
|
||||
// add resource handler
|
||||
let mut handler = ResourceHandler::default();
|
||||
@ -725,9 +791,15 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
|
||||
fn construct(&mut self) -> Cors {
|
||||
if !self.methods {
|
||||
self.allowed_methods(vec![Method::GET, Method::HEAD,
|
||||
Method::POST, Method::OPTIONS, Method::PUT,
|
||||
Method::PATCH, Method::DELETE]);
|
||||
self.allowed_methods(vec![
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::POST,
|
||||
Method::OPTIONS,
|
||||
Method::PUT,
|
||||
Method::PATCH,
|
||||
Method::DELETE,
|
||||
]);
|
||||
}
|
||||
|
||||
if let Some(e) = self.error.take() {
|
||||
@ -741,16 +813,23 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
}
|
||||
|
||||
if let AllOrSome::Some(ref origins) = cors.origins {
|
||||
let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v));
|
||||
let s = origins
|
||||
.iter()
|
||||
.fold(String::new(), |s, v| s + &format!("{}", v));
|
||||
cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap());
|
||||
}
|
||||
|
||||
if !self.expose_hdrs.is_empty() {
|
||||
cors.expose_hdrs = Some(
|
||||
self.expose_hdrs.iter().fold(
|
||||
String::new(), |s, v| s + v.as_str())[1..].to_owned());
|
||||
self.expose_hdrs
|
||||
.iter()
|
||||
.fold(String::new(), |s, v| s + v.as_str())[1..]
|
||||
.to_owned(),
|
||||
);
|
||||
}
|
||||
Cors {
|
||||
inner: Rc::new(cors),
|
||||
}
|
||||
Cors{inner: Rc::new(cors)}
|
||||
}
|
||||
|
||||
/// Finishes building and returns the built `Cors` instance.
|
||||
@ -758,13 +837,16 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
/// This method panics in case of any configuration error.
|
||||
pub fn finish(&mut self) -> Cors {
|
||||
if !self.resources.is_empty() {
|
||||
panic!("CorsBuilder::resource() was used,
|
||||
to construct CORS `.register(app)` method should be used");
|
||||
panic!(
|
||||
"CorsBuilder::resource() was used,
|
||||
to construct CORS `.register(app)` method should be used"
|
||||
);
|
||||
}
|
||||
self.construct()
|
||||
}
|
||||
|
||||
/// Finishes building Cors middleware and register middleware for application
|
||||
/// Finishes building Cors middleware and register middleware for
|
||||
/// application
|
||||
///
|
||||
/// This method panics in case of any configuration error or if non of
|
||||
/// resources are registered.
|
||||
@ -774,8 +856,9 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
}
|
||||
|
||||
let cors = self.construct();
|
||||
let mut app = self.app.take().expect(
|
||||
"CorsBuilder has to be constructed with Cors::for_app(app)");
|
||||
let mut app = self.app
|
||||
.take()
|
||||
.expect("CorsBuilder has to be constructed with Cors::for_app(app)");
|
||||
|
||||
// register resources
|
||||
for (path, mut resource) in self.resources.drain(..) {
|
||||
@ -787,7 +870,6 @@ impl<S: 'static> CorsBuilder<S> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -845,8 +927,8 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_origin_allows_all_origins() {
|
||||
let cors = Cors::default();
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com").finish();
|
||||
let mut req =
|
||||
TestRequest::with_header("Origin", "https://www.example.com").finish();
|
||||
|
||||
assert!(cors.start(&mut req).ok().unwrap().is_done())
|
||||
}
|
||||
@ -861,8 +943,7 @@ mod tests {
|
||||
.allowed_header(header::CONTENT_TYPE)
|
||||
.finish();
|
||||
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com")
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
@ -877,23 +958,35 @@ mod tests {
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
|
||||
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
|
||||
.header(
|
||||
header::ACCESS_CONTROL_REQUEST_HEADERS,
|
||||
"AUTHORIZATION,ACCEPT",
|
||||
)
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
let resp = cors.start(&mut req).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"*"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
assert_eq!(
|
||||
&b"3600"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes());
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_MAX_AGE)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
//assert_eq!(
|
||||
// &b"authorization,accept,content-type"[..],
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes());
|
||||
//assert_eq!(
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().
|
||||
// as_bytes()); assert_eq!(
|
||||
// &b"POST,GET,OPTIONS"[..],
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
|
||||
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().
|
||||
// as_bytes());
|
||||
|
||||
Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
|
||||
assert!(cors.start(&mut req).unwrap().is_done());
|
||||
@ -903,7 +996,8 @@ mod tests {
|
||||
#[should_panic(expected = "MissingOrigin")]
|
||||
fn test_validate_missing_origin() {
|
||||
let cors = Cors::build()
|
||||
.allowed_origin("https://www.example.com").finish();
|
||||
.allowed_origin("https://www.example.com")
|
||||
.finish();
|
||||
|
||||
let mut req = HttpRequest::default();
|
||||
cors.start(&mut req).unwrap();
|
||||
@ -913,7 +1007,8 @@ mod tests {
|
||||
#[should_panic(expected = "OriginNotAllowed")]
|
||||
fn test_validate_not_allowed_origin() {
|
||||
let cors = Cors::build()
|
||||
.allowed_origin("https://www.example.com").finish();
|
||||
.allowed_origin("https://www.example.com")
|
||||
.finish();
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.unknown.com")
|
||||
.method(Method::GET)
|
||||
@ -924,7 +1019,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_validate_origin() {
|
||||
let cors = Cors::build()
|
||||
.allowed_origin("https://www.example.com").finish();
|
||||
.allowed_origin("https://www.example.com")
|
||||
.finish();
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::GET)
|
||||
@ -940,16 +1036,23 @@ mod tests {
|
||||
let mut req = TestRequest::default().method(Method::GET).finish();
|
||||
let resp: HttpResponse = HttpResponse::Ok().into();
|
||||
let resp = cors.response(&mut req, resp).unwrap().response();
|
||||
assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
|
||||
assert!(
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.is_none()
|
||||
);
|
||||
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com")
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
let resp = cors.response(&mut req, resp).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"https://www.example.com"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -963,8 +1066,7 @@ mod tests {
|
||||
.allowed_header(header::CONTENT_TYPE)
|
||||
.finish();
|
||||
|
||||
let mut req = TestRequest::with_header(
|
||||
"Origin", "https://www.example.com")
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
|
||||
.method(Method::OPTIONS)
|
||||
.finish();
|
||||
|
||||
@ -972,10 +1074,15 @@ mod tests {
|
||||
let resp = cors.response(&mut req, resp).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"*"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
assert_eq!(
|
||||
&b"Origin"[..],
|
||||
resp.headers().get(header::VARY).unwrap().as_bytes());
|
||||
resp.headers().get(header::VARY).unwrap().as_bytes()
|
||||
);
|
||||
|
||||
let resp: HttpResponse = HttpResponse::Ok()
|
||||
.header(header::VARY, "Accept")
|
||||
@ -983,7 +1090,8 @@ mod tests {
|
||||
let resp = cors.response(&mut req, resp).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"Accept, Origin"[..],
|
||||
resp.headers().get(header::VARY).unwrap().as_bytes());
|
||||
resp.headers().get(header::VARY).unwrap().as_bytes()
|
||||
);
|
||||
|
||||
let cors = Cors::build()
|
||||
.disable_vary_header()
|
||||
@ -993,25 +1101,33 @@ mod tests {
|
||||
let resp = cors.response(&mut req, resp).unwrap().response();
|
||||
assert_eq!(
|
||||
&b"https://www.example.com"[..],
|
||||
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
|
||||
resp.headers()
|
||||
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
.unwrap()
|
||||
.as_bytes()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cors_resource() {
|
||||
let mut srv = test::TestServer::with_factory(
|
||||
|| App::new()
|
||||
.configure(
|
||||
|app| Cors::for_app(app)
|
||||
.allowed_origin("https://www.example.com")
|
||||
.resource("/test", |r| r.f(|_| HttpResponse::Ok()))
|
||||
.register()));
|
||||
let mut srv = test::TestServer::with_factory(|| {
|
||||
App::new().configure(|app| {
|
||||
Cors::for_app(app)
|
||||
.allowed_origin("https://www.example.com")
|
||||
.resource("/test", |r| r.f(|_| HttpResponse::Ok()))
|
||||
.register()
|
||||
})
|
||||
});
|
||||
|
||||
let request = srv.get().uri(srv.url("/test")).finish().unwrap();
|
||||
let response = srv.execute(request.send()).unwrap();
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let request = srv.get().uri(srv.url("/test"))
|
||||
.header("ORIGIN", "https://www.example.com").finish().unwrap();
|
||||
let request = srv.get()
|
||||
.uri(srv.url("/test"))
|
||||
.header("ORIGIN", "https://www.example.com")
|
||||
.finish()
|
||||
.unwrap();
|
||||
let response = srv.execute(request.send()).unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
@ -48,8 +48,8 @@ use std::borrow::Cow;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use bytes::Bytes;
|
||||
use error::{Result, ResponseError};
|
||||
use http::{HeaderMap, HttpTryFrom, Uri, header};
|
||||
use error::{ResponseError, Result};
|
||||
use http::{header, HeaderMap, HttpTryFrom, Uri};
|
||||
use httpmessage::HttpMessage;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
@ -59,13 +59,13 @@ use middleware::{Middleware, Started};
|
||||
#[derive(Debug, Fail)]
|
||||
pub enum CsrfError {
|
||||
/// The HTTP request header `Origin` was required but not provided.
|
||||
#[fail(display="Origin header required")]
|
||||
#[fail(display = "Origin header required")]
|
||||
MissingOrigin,
|
||||
/// The HTTP request header `Origin` could not be parsed correctly.
|
||||
#[fail(display="Could not parse Origin header")]
|
||||
#[fail(display = "Could not parse Origin header")]
|
||||
BadOrigin,
|
||||
/// The cross-site request was denied.
|
||||
#[fail(display="Cross-site request denied")]
|
||||
#[fail(display = "Cross-site request denied")]
|
||||
CsrDenied,
|
||||
}
|
||||
|
||||
@ -80,15 +80,14 @@ fn uri_origin(uri: &Uri) -> Option<String> {
|
||||
(Some(scheme), Some(host), Some(port)) => {
|
||||
Some(format!("{}://{}:{}", scheme, host, port))
|
||||
}
|
||||
(Some(scheme), Some(host), None) => {
|
||||
Some(format!("{}://{}", scheme, host))
|
||||
}
|
||||
_ => None
|
||||
(Some(scheme), Some(host), None) => Some(format!("{}://{}", scheme, host)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
|
||||
headers.get(header::ORIGIN)
|
||||
headers
|
||||
.get(header::ORIGIN)
|
||||
.map(|origin| {
|
||||
origin
|
||||
.to_str()
|
||||
@ -96,15 +95,14 @@ fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
|
||||
.map(|o| o.into())
|
||||
})
|
||||
.or_else(|| {
|
||||
headers.get(header::REFERER)
|
||||
.map(|referer| {
|
||||
Uri::try_from(Bytes::from(referer.as_bytes()))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(uri_origin)
|
||||
.ok_or(CsrfError::BadOrigin)
|
||||
.map(|o| o.into())
|
||||
})
|
||||
headers.get(header::REFERER).map(|referer| {
|
||||
Uri::try_from(Bytes::from(referer.as_bytes()))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(uri_origin)
|
||||
.ok_or(CsrfError::BadOrigin)
|
||||
.map(|o| o.into())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -194,7 +192,8 @@ impl CsrfFilter {
|
||||
let is_upgrade = req.headers().contains_key(header::UPGRADE);
|
||||
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
|
||||
|
||||
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
|
||||
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with"))
|
||||
{
|
||||
Ok(())
|
||||
} else if let Some(header) = origin(req.headers()) {
|
||||
match header {
|
||||
@ -225,8 +224,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_safe() {
|
||||
let csrf = CsrfFilter::new()
|
||||
.allowed_origin("https://www.example.com");
|
||||
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
|
||||
.method(Method::HEAD)
|
||||
@ -237,8 +235,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_csrf() {
|
||||
let csrf = CsrfFilter::new()
|
||||
.allowed_origin("https://www.example.com");
|
||||
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
|
||||
|
||||
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
|
||||
.method(Method::POST)
|
||||
@ -249,11 +246,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_referer() {
|
||||
let csrf = CsrfFilter::new()
|
||||
.allowed_origin("https://www.example.com");
|
||||
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
|
||||
|
||||
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param")
|
||||
.method(Method::POST)
|
||||
let mut req = TestRequest::with_header(
|
||||
"Referer",
|
||||
"https://www.example.com/some/path?query=param",
|
||||
).method(Method::POST)
|
||||
.finish();
|
||||
|
||||
assert!(csrf.start(&mut req).is_ok());
|
||||
@ -261,8 +259,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_upgrade() {
|
||||
let strict_csrf = CsrfFilter::new()
|
||||
.allowed_origin("https://www.example.com");
|
||||
let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
|
||||
|
||||
let lax_csrf = CsrfFilter::new()
|
||||
.allowed_origin("https://www.example.com")
|
||||
|
@ -1,11 +1,11 @@
|
||||
//! Default response headers
|
||||
use http::{HeaderMap, HttpTryFrom};
|
||||
use http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
|
||||
use http::{HeaderMap, HttpTryFrom};
|
||||
|
||||
use error::Result;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Response, Middleware};
|
||||
use middleware::{Middleware, Response};
|
||||
|
||||
/// `Middleware` for setting default response headers.
|
||||
///
|
||||
@ -27,14 +27,17 @@ use middleware::{Response, Middleware};
|
||||
/// .finish();
|
||||
/// }
|
||||
/// ```
|
||||
pub struct DefaultHeaders{
|
||||
pub struct DefaultHeaders {
|
||||
ct: bool,
|
||||
headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl Default for DefaultHeaders {
|
||||
fn default() -> Self {
|
||||
DefaultHeaders{ct: false, headers: HeaderMap::new()}
|
||||
DefaultHeaders {
|
||||
ct: false,
|
||||
headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,15 +51,16 @@ impl DefaultHeaders {
|
||||
#[inline]
|
||||
#[cfg_attr(feature = "cargo-clippy", allow(match_wild_err_arm))]
|
||||
pub fn header<K, V>(mut self, key: K, value: V) -> Self
|
||||
where HeaderName: HttpTryFrom<K>,
|
||||
HeaderValue: HttpTryFrom<V>
|
||||
where
|
||||
HeaderName: HttpTryFrom<K>,
|
||||
HeaderValue: HttpTryFrom<V>,
|
||||
{
|
||||
match HeaderName::try_from(key) {
|
||||
Ok(key) => {
|
||||
match HeaderValue::try_from(value) {
|
||||
Ok(value) => { self.headers.append(key, value); }
|
||||
Err(_) => panic!("Can not create header value"),
|
||||
Ok(key) => match HeaderValue::try_from(value) {
|
||||
Ok(value) => {
|
||||
self.headers.append(key, value);
|
||||
}
|
||||
Err(_) => panic!("Can not create header value"),
|
||||
},
|
||||
Err(_) => panic!("Can not create header name"),
|
||||
}
|
||||
@ -71,8 +75,9 @@ impl DefaultHeaders {
|
||||
}
|
||||
|
||||
impl<S> Middleware<S> for DefaultHeaders {
|
||||
|
||||
fn response(&self, _: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
|
||||
fn response(
|
||||
&self, _: &mut HttpRequest<S>, mut resp: HttpResponse
|
||||
) -> Result<Response> {
|
||||
for (key, value) in self.headers.iter() {
|
||||
if !resp.headers().contains_key(key) {
|
||||
resp.headers_mut().insert(key, value.clone());
|
||||
@ -81,7 +86,9 @@ impl<S> Middleware<S> for DefaultHeaders {
|
||||
// default content-type
|
||||
if self.ct && !resp.headers().contains_key(CONTENT_TYPE) {
|
||||
resp.headers_mut().insert(
|
||||
CONTENT_TYPE, HeaderValue::from_static("application/octet-stream"));
|
||||
CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
}
|
||||
Ok(Response::Done(resp))
|
||||
}
|
||||
@ -94,8 +101,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_default_headers() {
|
||||
let mw = DefaultHeaders::new()
|
||||
.header(CONTENT_TYPE, "0001");
|
||||
let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001");
|
||||
|
||||
let mut req = HttpRequest::default();
|
||||
|
||||
@ -106,7 +112,9 @@ mod tests {
|
||||
};
|
||||
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
|
||||
|
||||
let resp = HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish();
|
||||
let resp = HttpResponse::Ok()
|
||||
.header(CONTENT_TYPE, "0002")
|
||||
.finish();
|
||||
let resp = match mw.response(&mut req, resp) {
|
||||
Ok(Response::Done(resp)) => resp,
|
||||
_ => panic!(),
|
||||
|
@ -6,14 +6,13 @@ use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Middleware, Response};
|
||||
|
||||
|
||||
type ErrorHandler<S> = Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response>;
|
||||
|
||||
/// `Middleware` for allowing custom handlers for responses.
|
||||
///
|
||||
/// You can use `ErrorHandlers::handler()` method to register a custom error handler
|
||||
/// for specific status code. You can modify existing response or create completly new
|
||||
/// one.
|
||||
/// You can use `ErrorHandlers::handler()` method to register a custom error
|
||||
/// handler for specific status code. You can modify existing response or
|
||||
/// create completly new one.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
@ -53,7 +52,6 @@ impl<S> Default for ErrorHandlers<S> {
|
||||
}
|
||||
|
||||
impl<S> ErrorHandlers<S> {
|
||||
|
||||
/// Construct new `ErrorHandlers` instance
|
||||
pub fn new() -> Self {
|
||||
ErrorHandlers::default()
|
||||
@ -61,7 +59,8 @@ impl<S> ErrorHandlers<S> {
|
||||
|
||||
/// Register error handler for specified status code
|
||||
pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
|
||||
where F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static
|
||||
where
|
||||
F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static,
|
||||
{
|
||||
self.handlers.insert(status, Box::new(handler));
|
||||
self
|
||||
@ -69,8 +68,9 @@ impl<S> ErrorHandlers<S> {
|
||||
}
|
||||
|
||||
impl<S: 'static> Middleware<S> for ErrorHandlers<S> {
|
||||
|
||||
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
|
||||
fn response(
|
||||
&self, req: &mut HttpRequest<S>, resp: HttpResponse
|
||||
) -> Result<Response> {
|
||||
if let Some(handler) = self.handlers.get(&resp.status()) {
|
||||
handler(req, resp)
|
||||
} else {
|
||||
@ -90,11 +90,11 @@ mod tests {
|
||||
builder.header(CONTENT_TYPE, "0001");
|
||||
Ok(Response::Done(builder.into()))
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_handler() {
|
||||
let mw = ErrorHandlers::new()
|
||||
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
|
||||
let mw =
|
||||
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
|
||||
|
||||
let mut req = HttpRequest::default();
|
||||
let resp = HttpResponse::InternalServerError().finish();
|
||||
|
@ -4,14 +4,14 @@ use std::fmt;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
use libc;
|
||||
use time;
|
||||
use regex::Regex;
|
||||
use time;
|
||||
|
||||
use error::Result;
|
||||
use httpmessage::HttpMessage;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Middleware, Started, Finished};
|
||||
use middleware::{Finished, Middleware, Started};
|
||||
|
||||
/// `Middleware` for logging request and response info to the terminal.
|
||||
/// `Logger` middleware uses standard log crate to log information. You should
|
||||
@ -21,7 +21,8 @@ use middleware::{Middleware, Started, Finished};
|
||||
/// ## Usage
|
||||
///
|
||||
/// Create `Logger` middleware with the specified `format`.
|
||||
/// Default `Logger` could be created with `default` method, it uses the default format:
|
||||
/// Default `Logger` could be created with `default` method, it uses the
|
||||
/// default format:
|
||||
///
|
||||
/// ```ignore
|
||||
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
|
||||
@ -59,7 +60,8 @@ use middleware::{Middleware, Started, Finished};
|
||||
///
|
||||
/// `%b` Size of response in bytes, including HTTP headers
|
||||
///
|
||||
/// `%T` Time taken to serve the request, in seconds with floating fraction in .06f format
|
||||
/// `%T` Time taken to serve the request, in seconds with floating fraction in
|
||||
/// .06f format
|
||||
///
|
||||
/// `%D` Time taken to serve the request, in milliseconds
|
||||
///
|
||||
@ -76,7 +78,9 @@ pub struct Logger {
|
||||
impl Logger {
|
||||
/// Create `Logger` middleware with the specified `format`.
|
||||
pub fn new(format: &str) -> Logger {
|
||||
Logger { format: Format::new(format) }
|
||||
Logger {
|
||||
format: Format::new(format),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -87,14 +91,15 @@ impl Default for Logger {
|
||||
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
|
||||
/// ```
|
||||
fn default() -> Logger {
|
||||
Logger { format: Format::default() }
|
||||
Logger {
|
||||
format: Format::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct StartTime(time::Tm);
|
||||
|
||||
impl Logger {
|
||||
|
||||
fn log<S>(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) {
|
||||
let entry_time = req.extensions().get::<StartTime>().unwrap().0;
|
||||
|
||||
@ -109,7 +114,6 @@ impl Logger {
|
||||
}
|
||||
|
||||
impl<S> Middleware<S> for Logger {
|
||||
|
||||
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||
req.extensions().insert(StartTime(time::now()));
|
||||
Ok(Started::Done)
|
||||
@ -153,29 +157,26 @@ impl Format {
|
||||
idx = m.end();
|
||||
|
||||
if let Some(key) = cap.get(2) {
|
||||
results.push(
|
||||
match cap.get(3).unwrap().as_str() {
|
||||
"i" => FormatText::RequestHeader(key.as_str().to_owned()),
|
||||
"o" => FormatText::ResponseHeader(key.as_str().to_owned()),
|
||||
"e" => FormatText::EnvironHeader(key.as_str().to_owned()),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
results.push(match cap.get(3).unwrap().as_str() {
|
||||
"i" => FormatText::RequestHeader(key.as_str().to_owned()),
|
||||
"o" => FormatText::ResponseHeader(key.as_str().to_owned()),
|
||||
"e" => FormatText::EnvironHeader(key.as_str().to_owned()),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
} else {
|
||||
let m = cap.get(1).unwrap();
|
||||
results.push(
|
||||
match m.as_str() {
|
||||
"%" => FormatText::Percent,
|
||||
"a" => FormatText::RemoteAddr,
|
||||
"t" => FormatText::RequestTime,
|
||||
"P" => FormatText::Pid,
|
||||
"r" => FormatText::RequestLine,
|
||||
"s" => FormatText::ResponseStatus,
|
||||
"b" => FormatText::ResponseSize,
|
||||
"T" => FormatText::Time,
|
||||
"D" => FormatText::TimeMillis,
|
||||
_ => FormatText::Str(m.as_str().to_owned()),
|
||||
}
|
||||
);
|
||||
results.push(match m.as_str() {
|
||||
"%" => FormatText::Percent,
|
||||
"a" => FormatText::RemoteAddr,
|
||||
"t" => FormatText::RequestTime,
|
||||
"P" => FormatText::Pid,
|
||||
"r" => FormatText::RequestLine,
|
||||
"s" => FormatText::ResponseStatus,
|
||||
"b" => FormatText::ResponseSize,
|
||||
"T" => FormatText::Time,
|
||||
"D" => FormatText::TimeMillis,
|
||||
_ => FormatText::Str(m.as_str().to_owned()),
|
||||
});
|
||||
}
|
||||
}
|
||||
if idx != s.len() {
|
||||
@ -207,12 +208,10 @@ pub enum FormatText {
|
||||
}
|
||||
|
||||
impl FormatText {
|
||||
|
||||
fn render<S>(&self, fmt: &mut Formatter,
|
||||
req: &HttpRequest<S>,
|
||||
resp: &HttpResponse,
|
||||
entry_time: time::Tm) -> Result<(), fmt::Error>
|
||||
{
|
||||
fn render<S>(
|
||||
&self, fmt: &mut Formatter, req: &HttpRequest<S>, resp: &HttpResponse,
|
||||
entry_time: time::Tm,
|
||||
) -> Result<(), fmt::Error> {
|
||||
match *self {
|
||||
FormatText::Str(ref string) => fmt.write_str(string),
|
||||
FormatText::Percent => "%".fmt(fmt),
|
||||
@ -220,26 +219,33 @@ impl FormatText {
|
||||
if req.query_string().is_empty() {
|
||||
fmt.write_fmt(format_args!(
|
||||
"{} {} {:?}",
|
||||
req.method(), req.path(), req.version()))
|
||||
req.method(),
|
||||
req.path(),
|
||||
req.version()
|
||||
))
|
||||
} else {
|
||||
fmt.write_fmt(format_args!(
|
||||
"{} {}?{} {:?}",
|
||||
req.method(), req.path(), req.query_string(), req.version()))
|
||||
req.method(),
|
||||
req.path(),
|
||||
req.query_string(),
|
||||
req.version()
|
||||
))
|
||||
}
|
||||
},
|
||||
}
|
||||
FormatText::ResponseStatus => resp.status().as_u16().fmt(fmt),
|
||||
FormatText::ResponseSize => resp.response_size().fmt(fmt),
|
||||
FormatText::Pid => unsafe{libc::getpid().fmt(fmt)},
|
||||
FormatText::Pid => unsafe { libc::getpid().fmt(fmt) },
|
||||
FormatText::Time => {
|
||||
let rt = time::now() - entry_time;
|
||||
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000_000.0;
|
||||
fmt.write_fmt(format_args!("{:.6}", rt))
|
||||
},
|
||||
}
|
||||
FormatText::TimeMillis => {
|
||||
let rt = time::now() - entry_time;
|
||||
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000.0;
|
||||
fmt.write_fmt(format_args!("{:.6}", rt))
|
||||
},
|
||||
}
|
||||
FormatText::RemoteAddr => {
|
||||
if let Some(remote) = req.connection_info().remote() {
|
||||
return remote.fmt(fmt);
|
||||
@ -247,14 +253,17 @@ impl FormatText {
|
||||
"-".fmt(fmt)
|
||||
}
|
||||
}
|
||||
FormatText::RequestTime => {
|
||||
entry_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")
|
||||
.unwrap()
|
||||
.fmt(fmt)
|
||||
}
|
||||
FormatText::RequestTime => entry_time
|
||||
.strftime("[%d/%b/%Y:%H:%M:%S %z]")
|
||||
.unwrap()
|
||||
.fmt(fmt),
|
||||
FormatText::RequestHeader(ref name) => {
|
||||
let s = if let Some(val) = req.headers().get(name) {
|
||||
if let Ok(s) = val.to_str() { s } else { "-" }
|
||||
if let Ok(s) = val.to_str() {
|
||||
s
|
||||
} else {
|
||||
"-"
|
||||
}
|
||||
} else {
|
||||
"-"
|
||||
};
|
||||
@ -262,7 +271,11 @@ impl FormatText {
|
||||
}
|
||||
FormatText::ResponseHeader(ref name) => {
|
||||
let s = if let Some(val) = resp.headers().get(name) {
|
||||
if let Ok(s) = val.to_str() { s } else { "-" }
|
||||
if let Ok(s) = val.to_str() {
|
||||
s
|
||||
} else {
|
||||
"-"
|
||||
}
|
||||
} else {
|
||||
"-"
|
||||
};
|
||||
@ -279,8 +292,7 @@ impl FormatText {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct FormatDisplay<'a>(
|
||||
&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
|
||||
pub(crate) struct FormatDisplay<'a>(&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
|
||||
|
||||
impl<'a> fmt::Display for FormatDisplay<'a> {
|
||||
fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
@ -291,19 +303,27 @@ impl<'a> fmt::Display for FormatDisplay<'a> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use http::header::{self, HeaderMap};
|
||||
use http::{Method, StatusCode, Uri, Version};
|
||||
use std::str::FromStr;
|
||||
use time;
|
||||
use http::{Method, Version, StatusCode, Uri};
|
||||
use http::header::{self, HeaderMap};
|
||||
|
||||
#[test]
|
||||
fn test_logger() {
|
||||
let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test");
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB"));
|
||||
headers.insert(
|
||||
header::USER_AGENT,
|
||||
header::HeaderValue::from_static("ACTIX-WEB"),
|
||||
);
|
||||
let mut req = HttpRequest::new(
|
||||
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
|
||||
Method::GET,
|
||||
Uri::from_str("/").unwrap(),
|
||||
Version::HTTP_11,
|
||||
headers,
|
||||
None,
|
||||
);
|
||||
let resp = HttpResponse::build(StatusCode::OK)
|
||||
.header("X-Test", "ttt")
|
||||
.force_close()
|
||||
@ -333,10 +353,20 @@ mod tests {
|
||||
let format = Format::default();
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB"));
|
||||
headers.insert(
|
||||
header::USER_AGENT,
|
||||
header::HeaderValue::from_static("ACTIX-WEB"),
|
||||
);
|
||||
let req = HttpRequest::new(
|
||||
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
|
||||
let resp = HttpResponse::build(StatusCode::OK).force_close().finish();
|
||||
Method::GET,
|
||||
Uri::from_str("/").unwrap(),
|
||||
Version::HTTP_11,
|
||||
headers,
|
||||
None,
|
||||
);
|
||||
let resp = HttpResponse::build(StatusCode::OK)
|
||||
.force_close()
|
||||
.finish();
|
||||
let entry_time = time::now();
|
||||
|
||||
let render = |fmt: &mut Formatter| {
|
||||
@ -351,9 +381,15 @@ mod tests {
|
||||
assert!(s.contains("ACTIX-WEB"));
|
||||
|
||||
let req = HttpRequest::new(
|
||||
Method::GET, Uri::from_str("/?test").unwrap(),
|
||||
Version::HTTP_11, HeaderMap::new(), None);
|
||||
let resp = HttpResponse::build(StatusCode::OK).force_close().finish();
|
||||
Method::GET,
|
||||
Uri::from_str("/?test").unwrap(),
|
||||
Version::HTTP_11,
|
||||
HeaderMap::new(),
|
||||
None,
|
||||
);
|
||||
let resp = HttpResponse::build(StatusCode::OK)
|
||||
.force_close()
|
||||
.finish();
|
||||
let entry_time = time::now();
|
||||
|
||||
let render = |fmt: &mut Formatter| {
|
||||
|
@ -7,19 +7,19 @@ use httpresponse::HttpResponse;
|
||||
|
||||
mod logger;
|
||||
|
||||
#[cfg(feature = "session")]
|
||||
mod session;
|
||||
mod defaultheaders;
|
||||
mod errhandlers;
|
||||
pub mod cors;
|
||||
pub mod csrf;
|
||||
pub use self::logger::Logger;
|
||||
pub use self::errhandlers::ErrorHandlers;
|
||||
mod defaultheaders;
|
||||
mod errhandlers;
|
||||
#[cfg(feature = "session")]
|
||||
mod session;
|
||||
pub use self::defaultheaders::DefaultHeaders;
|
||||
pub use self::errhandlers::ErrorHandlers;
|
||||
pub use self::logger::Logger;
|
||||
|
||||
#[cfg(feature = "session")]
|
||||
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage,
|
||||
CookieSessionError, CookieSessionBackend};
|
||||
pub use self::session::{CookieSessionBackend, CookieSessionError, RequestSession,
|
||||
Session, SessionBackend, SessionImpl, SessionStorage};
|
||||
|
||||
/// Middleware start result
|
||||
pub enum Started {
|
||||
@ -29,7 +29,7 @@ pub enum Started {
|
||||
/// handler execution halts.
|
||||
Response(HttpResponse),
|
||||
/// Execution completed, runs future to completion.
|
||||
Future(Box<Future<Item=Option<HttpResponse>, Error=Error>>),
|
||||
Future(Box<Future<Item = Option<HttpResponse>, Error = Error>>),
|
||||
}
|
||||
|
||||
/// Middleware execution result
|
||||
@ -37,7 +37,7 @@ pub enum Response {
|
||||
/// New http response got generated
|
||||
Done(HttpResponse),
|
||||
/// Result is a future that resolves to a new http response
|
||||
Future(Box<Future<Item=HttpResponse, Error=Error>>),
|
||||
Future(Box<Future<Item = HttpResponse, Error = Error>>),
|
||||
}
|
||||
|
||||
/// Middleware finish result
|
||||
@ -45,13 +45,12 @@ pub enum Finished {
|
||||
/// Execution completed
|
||||
Done,
|
||||
/// Execution completed, but run future to completion
|
||||
Future(Box<Future<Item=(), Error=Error>>),
|
||||
Future(Box<Future<Item = (), Error = Error>>),
|
||||
}
|
||||
|
||||
/// Middleware definition
|
||||
#[allow(unused_variables)]
|
||||
pub trait Middleware<S>: 'static {
|
||||
|
||||
/// Method is called when request is ready. It may return
|
||||
/// future, which should resolve before next middleware get called.
|
||||
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||
@ -60,7 +59,9 @@ pub trait Middleware<S>: 'static {
|
||||
|
||||
/// Method is called when handler returns response,
|
||||
/// but before sending http message to peer.
|
||||
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
|
||||
fn response(
|
||||
&self, req: &mut HttpRequest<S>, resp: HttpResponse
|
||||
) -> Result<Response> {
|
||||
Ok(Response::Done(resp))
|
||||
}
|
||||
|
||||
|
@ -1,21 +1,21 @@
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use cookie::{Cookie, CookieJar, Key};
|
||||
use futures::Future;
|
||||
use futures::future::{FutureResult, err as FutErr, ok as FutOk};
|
||||
use http::header::{self, HeaderValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use serde_json::error::Error as JsonError;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use http::header::{self, HeaderValue};
|
||||
use time::Duration;
|
||||
use cookie::{CookieJar, Cookie, Key};
|
||||
use futures::Future;
|
||||
use futures::future::{FutureResult, ok as FutOk, err as FutErr};
|
||||
|
||||
use error::{Result, Error, ResponseError};
|
||||
use error::{Error, ResponseError, Result};
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
use middleware::{Middleware, Started, Response};
|
||||
use middleware::{Middleware, Response, Started};
|
||||
|
||||
/// The helper trait to obtain your session data from a request.
|
||||
///
|
||||
@ -40,14 +40,13 @@ pub trait RequestSession {
|
||||
}
|
||||
|
||||
impl<S> RequestSession for HttpRequest<S> {
|
||||
|
||||
fn session(&mut self) -> Session {
|
||||
if let Some(s_impl) = self.extensions().get_mut::<Arc<SessionImplBox>>() {
|
||||
if let Some(s) = Arc::get_mut(s_impl) {
|
||||
return Session(s.0.as_mut())
|
||||
return Session(s.0.as_mut());
|
||||
}
|
||||
}
|
||||
Session(unsafe{&mut DUMMY})
|
||||
Session(unsafe { &mut DUMMY })
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,7 +75,6 @@ impl<S> RequestSession for HttpRequest<S> {
|
||||
pub struct Session<'a>(&'a mut SessionImpl);
|
||||
|
||||
impl<'a> Session<'a> {
|
||||
|
||||
/// Get a `value` from the session.
|
||||
pub fn get<T: Deserialize<'a>>(&'a self, key: &str) -> Result<Option<T>> {
|
||||
if let Some(s) = self.0.get(key) {
|
||||
@ -136,24 +134,25 @@ impl<S, T: SessionBackend<S>> SessionStorage<T, S> {
|
||||
}
|
||||
|
||||
impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
|
||||
|
||||
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
|
||||
let mut req = req.clone();
|
||||
|
||||
let fut = self.0.from_request(&mut req)
|
||||
.then(move |res| {
|
||||
match res {
|
||||
Ok(sess) => {
|
||||
req.extensions().insert(Arc::new(SessionImplBox(Box::new(sess))));
|
||||
FutOk(None)
|
||||
},
|
||||
Err(err) => FutErr(err)
|
||||
let fut = self.0
|
||||
.from_request(&mut req)
|
||||
.then(move |res| match res {
|
||||
Ok(sess) => {
|
||||
req.extensions()
|
||||
.insert(Arc::new(SessionImplBox(Box::new(sess))));
|
||||
FutOk(None)
|
||||
}
|
||||
Err(err) => FutErr(err),
|
||||
});
|
||||
Ok(Started::Future(Box::new(fut)))
|
||||
}
|
||||
|
||||
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
|
||||
fn response(
|
||||
&self, req: &mut HttpRequest<S>, resp: HttpResponse
|
||||
) -> Result<Response> {
|
||||
if let Some(s_box) = req.extensions().remove::<Arc<SessionImplBox>>() {
|
||||
s_box.0.write(resp)
|
||||
} else {
|
||||
@ -165,7 +164,6 @@ impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
|
||||
/// A simple key-value storage interface that is internally used by `Session`.
|
||||
#[doc(hidden)]
|
||||
pub trait SessionImpl: 'static {
|
||||
|
||||
fn get(&self, key: &str) -> Option<&str>;
|
||||
|
||||
fn set(&mut self, key: &str, value: String);
|
||||
@ -182,7 +180,7 @@ pub trait SessionImpl: 'static {
|
||||
#[doc(hidden)]
|
||||
pub trait SessionBackend<S>: Sized + 'static {
|
||||
type Session: SessionImpl;
|
||||
type ReadFuture: Future<Item=Self::Session, Error=Error>;
|
||||
type ReadFuture: Future<Item = Self::Session, Error = Error>;
|
||||
|
||||
/// Parse the session from request and load data from a storage backend.
|
||||
fn from_request(&self, request: &mut HttpRequest<S>) -> Self::ReadFuture;
|
||||
@ -194,8 +192,9 @@ struct DummySessionImpl;
|
||||
static mut DUMMY: DummySessionImpl = DummySessionImpl;
|
||||
|
||||
impl SessionImpl for DummySessionImpl {
|
||||
|
||||
fn get(&self, _: &str) -> Option<&str> { None }
|
||||
fn get(&self, _: &str) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
fn set(&mut self, _: &str, _: String) {}
|
||||
fn remove(&mut self, _: &str) {}
|
||||
fn clear(&mut self) {}
|
||||
@ -215,17 +214,16 @@ pub struct CookieSession {
|
||||
#[derive(Fail, Debug)]
|
||||
pub enum CookieSessionError {
|
||||
/// Size of the serialized session is greater than 4000 bytes.
|
||||
#[fail(display="Size of the serialized session is greater than 4000 bytes.")]
|
||||
#[fail(display = "Size of the serialized session is greater than 4000 bytes.")]
|
||||
Overflow,
|
||||
/// Fail to serialize session.
|
||||
#[fail(display="Fail to serialize session")]
|
||||
#[fail(display = "Fail to serialize session")]
|
||||
Serialize(JsonError),
|
||||
}
|
||||
|
||||
impl ResponseError for CookieSessionError {}
|
||||
|
||||
impl SessionImpl for CookieSession {
|
||||
|
||||
fn get(&self, key: &str) -> Option<&str> {
|
||||
if let Some(s) = self.state.get(key) {
|
||||
Some(s)
|
||||
@ -259,7 +257,7 @@ impl SessionImpl for CookieSession {
|
||||
|
||||
enum CookieSecurity {
|
||||
Signed,
|
||||
Private
|
||||
Private,
|
||||
}
|
||||
|
||||
struct CookieSessionInner {
|
||||
@ -273,7 +271,6 @@ struct CookieSessionInner {
|
||||
}
|
||||
|
||||
impl CookieSessionInner {
|
||||
|
||||
fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner {
|
||||
CookieSessionInner {
|
||||
security,
|
||||
@ -286,11 +283,13 @@ impl CookieSessionInner {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_cookie(&self, resp: &mut HttpResponse, state: &HashMap<String, String>) -> Result<()> {
|
||||
let value = serde_json::to_string(&state)
|
||||
.map_err(CookieSessionError::Serialize)?;
|
||||
fn set_cookie(
|
||||
&self, resp: &mut HttpResponse, state: &HashMap<String, String>
|
||||
) -> Result<()> {
|
||||
let value =
|
||||
serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?;
|
||||
if value.len() > 4064 {
|
||||
return Err(CookieSessionError::Overflow.into())
|
||||
return Err(CookieSessionError::Overflow.into());
|
||||
}
|
||||
|
||||
let mut cookie = Cookie::new(self.name.clone(), value);
|
||||
@ -330,7 +329,9 @@ impl CookieSessionInner {
|
||||
|
||||
let cookie_opt = match self.security {
|
||||
CookieSecurity::Signed => jar.signed(&self.key).get(&self.name),
|
||||
CookieSecurity::Private => jar.private(&self.key).get(&self.name),
|
||||
CookieSecurity::Private => {
|
||||
jar.private(&self.key).get(&self.name)
|
||||
}
|
||||
};
|
||||
if let Some(cookie) = cookie_opt {
|
||||
if let Ok(val) = serde_json::from_str(cookie.value()) {
|
||||
@ -347,20 +348,24 @@ impl CookieSessionInner {
|
||||
/// Use cookies for session storage.
|
||||
///
|
||||
/// `CookieSessionBackend` creates sessions which are limited to storing
|
||||
/// fewer than 4000 bytes of data (as the payload must fit into a single cookie).
|
||||
/// An Internal Server Error is generated if the session contains more than 4000 bytes.
|
||||
/// fewer than 4000 bytes of data (as the payload must fit into a single
|
||||
/// cookie). An Internal Server Error is generated if the session contains more
|
||||
/// than 4000 bytes.
|
||||
///
|
||||
/// A cookie may have a security policy of *signed* or *private*. Each has a respective `CookieSessionBackend` constructor.
|
||||
/// A cookie may have a security policy of *signed* or *private*. Each has a
|
||||
/// respective `CookieSessionBackend` constructor.
|
||||
///
|
||||
/// A *signed* cookie is stored on the client as plaintext alongside
|
||||
/// a signature such that the cookie may be viewed but not modified by the client.
|
||||
/// a signature such that the cookie may be viewed but not modified by the
|
||||
/// client.
|
||||
///
|
||||
/// A *private* cookie is stored on the client as encrypted text
|
||||
/// such that it may neither be viewed nor modified by the client.
|
||||
///
|
||||
/// The constructors take a key as an argument.
|
||||
/// This is the private key for cookie session - when this value is changed, all session data is lost.
|
||||
/// The constructors will panic if the key is less than 32 bytes in length.
|
||||
/// This is the private key for cookie session - when this value is changed,
|
||||
/// all session data is lost. The constructors will panic if the key is less
|
||||
/// than 32 bytes in length.
|
||||
///
|
||||
///
|
||||
/// # Example
|
||||
@ -380,21 +385,24 @@ impl CookieSessionInner {
|
||||
pub struct CookieSessionBackend(Rc<CookieSessionInner>);
|
||||
|
||||
impl CookieSessionBackend {
|
||||
|
||||
/// Construct new *signed* `CookieSessionBackend` instance.
|
||||
///
|
||||
/// Panics if key length is less than 32 bytes.
|
||||
pub fn signed(key: &[u8]) -> CookieSessionBackend {
|
||||
CookieSessionBackend(
|
||||
Rc::new(CookieSessionInner::new(key, CookieSecurity::Signed)))
|
||||
CookieSessionBackend(Rc::new(CookieSessionInner::new(
|
||||
key,
|
||||
CookieSecurity::Signed,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Construct new *private* `CookieSessionBackend` instance.
|
||||
///
|
||||
/// Panics if key length is less than 32 bytes.
|
||||
pub fn private(key: &[u8]) -> CookieSessionBackend {
|
||||
CookieSessionBackend(
|
||||
Rc::new(CookieSessionInner::new(key, CookieSecurity::Private)))
|
||||
CookieSessionBackend(Rc::new(CookieSessionInner::new(
|
||||
key,
|
||||
CookieSecurity::Private,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Sets the `path` field in the session cookie being built.
|
||||
@ -432,17 +440,15 @@ impl CookieSessionBackend {
|
||||
}
|
||||
|
||||
impl<S> SessionBackend<S> for CookieSessionBackend {
|
||||
|
||||
type Session = CookieSession;
|
||||
type ReadFuture = FutureResult<CookieSession, Error>;
|
||||
|
||||
fn from_request(&self, req: &mut HttpRequest<S>) -> Self::ReadFuture {
|
||||
let state = self.0.load(req);
|
||||
FutOk(
|
||||
CookieSession {
|
||||
changed: false,
|
||||
inner: Rc::clone(&self.0),
|
||||
state,
|
||||
})
|
||||
FutOk(CookieSession {
|
||||
changed: false,
|
||||
inner: Rc::clone(&self.0),
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user