//! Custom handlers service for responses.
use std::rc::Rc;

use actix_service::{Service, Transform};
use futures::future::{err, ok, Either, Future, FutureResult};
use futures::Poll;
use hashbrown::HashMap;

use crate::dev::{ServiceRequest, ServiceResponse};
use crate::error::{Error, Result};
use crate::http::StatusCode;

/// Error handler response
pub enum ErrorHandlerResponse<B> {
    /// New http response got generated
    Response(ServiceResponse<B>),
    /// Result is a future that resolves to a new http response
    Future(Box<Future<Item = ServiceResponse<B>, Error = Error>>),
}

type ErrorHandler<B> = Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;

/// `Middleware` for allowing custom handlers for responses.
///
/// You can use `ErrorHandlers::handler()` method  to register a custom error
/// handler for specific status code. You can modify existing response or
/// create completely new one.
///
/// ## Example
///
/// ```rust
/// use actix_web::middleware::errhandlers::{ErrorHandlers, ErrorHandlerResponse};
/// use actix_web::{web, http, dev, App, HttpRequest, HttpResponse, Result};
///
/// fn render_500<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
///     res.response_mut()
///        .headers_mut()
///        .insert(http::header::CONTENT_TYPE, http::HeaderValue::from_static("Error"));
///     Ok(ErrorHandlerResponse::Response(res))
/// }
///
/// fn main() {
///     let app = App::new()
///         .wrap(
///             ErrorHandlers::new()
///                 .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500),
///         )
///         .service(web::resource("/test")
///             .route(web::get().to(|| HttpResponse::Ok()))
///             .route(web::head().to(|| HttpResponse::MethodNotAllowed())
///         ));
/// }
/// ```
pub struct ErrorHandlers<B> {
    handlers: Rc<HashMap<StatusCode, Box<ErrorHandler<B>>>>,
}

impl<B> Default for ErrorHandlers<B> {
    fn default() -> Self {
        ErrorHandlers {
            handlers: Rc::new(HashMap::new()),
        }
    }
}

impl<B> ErrorHandlers<B> {
    /// Construct new `ErrorHandlers` instance
    pub fn new() -> Self {
        ErrorHandlers::default()
    }

    /// Register error handler for specified status code
    pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
    where
        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
    {
        Rc::get_mut(&mut self.handlers)
            .unwrap()
            .insert(status, Box::new(handler));
        self
    }
}

impl<S, B> Transform<S> for ErrorHandlers<B>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = ErrorHandlersMiddleware<S, B>;
    type Future = FutureResult<Self::Transform, Self::InitError>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(ErrorHandlersMiddleware {
            service,
            handlers: self.handlers.clone(),
        })
    }
}

#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
    service: S,
    handlers: Rc<HashMap<StatusCode, Box<ErrorHandler<B>>>>,
}

impl<S, B> Service for ErrorHandlersMiddleware<S, B>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = Box<Future<Item = Self::Response, Error = Self::Error>>;

    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
        self.service.poll_ready()
    }

    fn call(&mut self, req: ServiceRequest) -> Self::Future {
        let handlers = self.handlers.clone();

        Box::new(self.service.call(req).and_then(move |res| {
            if let Some(handler) = handlers.get(&res.status()) {
                match handler(res) {
                    Ok(ErrorHandlerResponse::Response(res)) => Either::A(ok(res)),
                    Ok(ErrorHandlerResponse::Future(fut)) => Either::B(fut),
                    Err(e) => Either::A(err(e)),
                }
            } else {
                Either::A(ok(res))
            }
        }))
    }
}

#[cfg(test)]
mod tests {
    use actix_service::FnService;
    use futures::future::ok;

    use super::*;
    use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
    use crate::test::{self, TestRequest};
    use crate::HttpResponse;

    fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        res.response_mut()
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
        Ok(ErrorHandlerResponse::Response(res))
    }

    #[test]
    fn test_handler() {
        let srv = FnService::new(|req: ServiceRequest| {
            req.into_response(HttpResponse::InternalServerError().finish())
        });

        let mut mw = test::block_on(
            ErrorHandlers::new()
                .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500)
                .new_transform(srv),
        )
        .unwrap();

        let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    fn render_500_async<B: 'static>(
        mut res: ServiceResponse<B>,
    ) -> Result<ErrorHandlerResponse<B>> {
        res.response_mut()
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
        Ok(ErrorHandlerResponse::Future(Box::new(ok(res))))
    }

    #[test]
    fn test_handler_async() {
        let srv = FnService::new(|req: ServiceRequest| {
            req.into_response(HttpResponse::InternalServerError().finish())
        });

        let mut mw = test::block_on(
            ErrorHandlers::new()
                .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async)
                .new_transform(srv),
        )
        .unwrap();

        let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }
}