//! Middleware for setting default response headers use std::convert::TryFrom; use std::rc::Rc; use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::{Error as HttpError, HeaderMap}; use crate::service::{ServiceRequest, ServiceResponse}; use crate::Error; /// `Middleware` for setting default response headers. /// /// This middleware does not set header if response headers already contains it. /// /// ```rust /// use actix_web::{web, http, middleware, App, HttpResponse}; /// /// fn main() { /// let app = App::new() /// .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) /// .service( /// web::resource("/test") /// .route(web::get().to(|| HttpResponse::Ok())) /// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) /// ); /// } /// ``` #[derive(Clone)] pub struct DefaultHeaders { inner: Rc, } struct Inner { ct: bool, headers: HeaderMap, } impl Default for DefaultHeaders { fn default() -> Self { DefaultHeaders { inner: Rc::new(Inner { ct: false, headers: HeaderMap::new(), }), } } } impl DefaultHeaders { /// Construct `DefaultHeaders` middleware. pub fn new() -> DefaultHeaders { DefaultHeaders::default() } /// Set a header. #[inline] pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { #[allow(clippy::match_wild_err_arm)] match HeaderName::try_from(key) { Ok(key) => match HeaderValue::try_from(value) { Ok(value) => { Rc::get_mut(&mut self.inner) .expect("Multiple copies exist") .headers .append(key, value); } Err(_) => panic!("Can not create header value"), }, Err(_) => panic!("Can not create header name"), } self } /// Set *CONTENT-TYPE* header if response does not contain this header. pub fn content_type(mut self) -> Self { Rc::get_mut(&mut self.inner) .expect("Multiple copies exist") .ct = true; self } } impl Transform for DefaultHeaders where S: Service, Error = Error>, S::Future: 'static, { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type InitError = (); type Transform = DefaultHeadersMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(DefaultHeadersMiddleware { service, inner: self.inner.clone(), }) } } pub struct DefaultHeadersMiddleware { service: S, inner: Rc, } impl Service for DefaultHeadersMiddleware where S: Service, Error = Error>, S::Future: 'static, { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); let fut = self.service.call(req); async move { let mut res = fut.await?; // set response headers for (key, value) in inner.headers.iter() { if !res.headers().contains_key(key) { res.headers_mut().insert(key.clone(), value.clone()); } } // default content-type if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { res.headers_mut().insert( CONTENT_TYPE, HeaderValue::from_static("application/octet-stream"), ); } Ok(res) } .boxed_local() } } #[cfg(test)] mod tests { use actix_service::IntoService; use futures::future::ok; use super::*; use crate::dev::ServiceRequest; use crate::http::header::CONTENT_TYPE; use crate::test::{ok_service, TestRequest}; use crate::HttpResponse; #[actix_rt::test] async fn test_default_headers() { let mut mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") .new_transform(ok_service()) .await .unwrap(); let req = TestRequest::default().to_srv_request(); let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); let req = TestRequest::default().to_srv_request(); let srv = |req: ServiceRequest| { ok(req .into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())) }; let mut mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") .new_transform(srv.into_service()) .await .unwrap(); let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); } #[actix_rt::test] async fn test_content_type() { let srv = |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); let mut mw = DefaultHeaders::new() .content_type() .new_transform(srv.into_service()) .await .unwrap(); let req = TestRequest::default().to_srv_request(); let resp = mw.call(req).await.unwrap(); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), "application/octet-stream" ); } }