1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-12-01 02:54:36 +01:00
actix-web/src/middleware/defaultheaders.rs
2019-03-23 21:29:16 -07:00

203 lines
5.8 KiB
Rust

//! Middleware for setting default response headers
use std::rc::Rc;
use actix_service::{Service, Transform};
use futures::future::{ok, FutureResult};
use futures::{Future, Poll};
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use crate::http::{HeaderMap, HttpTryFrom};
use crate::service::{ServiceRequest, ServiceResponse};
/// `Middleware` for setting default response headers.
///
/// This middleware does not set header if response headers already contains it.
///
/// ```rust
/// use actix_web::{web, http, middleware, App, HttpResponse};
///
/// fn main() {
/// let app = App::new()
/// .middleware(middleware::DefaultHeaders::new().header("X-Version", "0.2"))
/// .service(
/// web::resource("/test")
/// .route(web::get().to(|| HttpResponse::Ok()))
/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed()))
/// );
/// }
/// ```
#[derive(Clone)]
pub struct DefaultHeaders {
inner: Rc<Inner>,
}
struct Inner {
ct: bool,
headers: HeaderMap,
}
impl Default for DefaultHeaders {
fn default() -> Self {
DefaultHeaders {
inner: Rc::new(Inner {
ct: false,
headers: HeaderMap::new(),
}),
}
}
}
impl DefaultHeaders {
/// Construct `DefaultHeaders` middleware.
pub fn new() -> DefaultHeaders {
DefaultHeaders::default()
}
/// Set a header.
#[inline]
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: HttpTryFrom<K>,
HeaderValue: HttpTryFrom<V>,
{
#[allow(clippy::match_wild_err_arm)]
match HeaderName::try_from(key) {
Ok(key) => match HeaderValue::try_from(value) {
Ok(value) => {
Rc::get_mut(&mut self.inner)
.expect("Multiple copies exist")
.headers
.append(key, value);
}
Err(_) => panic!("Can not create header value"),
},
Err(_) => panic!("Can not create header name"),
}
self
}
/// Set *CONTENT-TYPE* header if response does not contain this header.
pub fn content_type(mut self) -> Self {
Rc::get_mut(&mut self.inner)
.expect("Multiple copies exist")
.ct = true;
self
}
}
impl<S, P, B> Transform<S> for DefaultHeaders
where
S: Service<Request = ServiceRequest<P>, Response = ServiceResponse<B>>,
S::Future: 'static,
{
type Request = ServiceRequest<P>;
type Response = ServiceResponse<B>;
type Error = S::Error;
type InitError = ();
type Transform = DefaultHeadersMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future {
ok(DefaultHeadersMiddleware {
service,
inner: self.inner.clone(),
})
}
}
pub struct DefaultHeadersMiddleware<S> {
service: S,
inner: Rc<Inner>,
}
impl<S, P, B> Service for DefaultHeadersMiddleware<S>
where
S: Service<Request = ServiceRequest<P>, Response = ServiceResponse<B>>,
S::Future: 'static,
{
type Request = ServiceRequest<P>;
type Response = ServiceResponse<B>;
type Error = S::Error;
type Future = Box<Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
}
fn call(&mut self, req: ServiceRequest<P>) -> Self::Future {
let inner = self.inner.clone();
Box::new(self.service.call(req).map(move |mut res| {
// set response headers
for (key, value) in inner.headers.iter() {
if !res.headers().contains_key(key) {
res.headers_mut().insert(key, value.clone());
}
}
// default content-type
if inner.ct && !res.headers().contains_key(CONTENT_TYPE) {
res.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
res
}))
}
}
#[cfg(test)]
mod tests {
use actix_service::FnService;
use super::*;
use crate::dev::ServiceRequest;
use crate::http::header::CONTENT_TYPE;
use crate::test::{block_on, ok_service, TestRequest};
use crate::HttpResponse;
#[test]
fn test_default_headers() {
let mut mw = block_on(
DefaultHeaders::new()
.header(CONTENT_TYPE, "0001")
.new_transform(ok_service()),
)
.unwrap();
let req = TestRequest::default().to_service();
let resp = block_on(mw.call(req)).unwrap();
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let req = TestRequest::default().to_service();
let srv = FnService::new(|req: ServiceRequest<_>| {
req.into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())
});
let mut mw = block_on(
DefaultHeaders::new()
.header(CONTENT_TYPE, "0001")
.new_transform(srv),
)
.unwrap();
let resp = block_on(mw.call(req)).unwrap();
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
}
#[test]
fn test_content_type() {
let srv = FnService::new(|req: ServiceRequest<_>| {
req.into_response(HttpResponse::Ok().finish())
});
let mut mw =
block_on(DefaultHeaders::new().content_type().new_transform(srv)).unwrap();
let req = TestRequest::default().to_service();
let resp = block_on(mw.call(req)).unwrap();
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
"application/octet-stream"
);
}
}