//! Service that applies a timeout to requests. //! //! If the response does not complete within the specified timeout, the response //! will be aborted. use core::future::Future; use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll}; use core::{fmt, time}; use actix_rt::time::{delay_for, Delay}; use actix_service::{IntoService, Service, Transform}; use pin_project_lite::pin_project; /// Applies a timeout to requests. #[derive(Debug)] pub struct Timeout { timeout: time::Duration, _t: PhantomData, } /// Timeout error pub enum TimeoutError { /// Service error Service(E), /// Service call timeout Timeout, } impl From for TimeoutError { fn from(err: E) -> Self { TimeoutError::Service(err) } } impl fmt::Debug for TimeoutError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TimeoutError::Service(e) => write!(f, "TimeoutError::Service({:?})", e), TimeoutError::Timeout => write!(f, "TimeoutError::Timeout"), } } } impl fmt::Display for TimeoutError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TimeoutError::Service(e) => e.fmt(f), TimeoutError::Timeout => write!(f, "Service call timeout"), } } } impl PartialEq for TimeoutError { fn eq(&self, other: &TimeoutError) -> bool { match self { TimeoutError::Service(e1) => match other { TimeoutError::Service(e2) => e1 == e2, TimeoutError::Timeout => false, }, TimeoutError::Timeout => matches!(other, TimeoutError::Timeout), } } } impl Timeout { pub fn new(timeout: time::Duration) -> Self { Timeout { timeout, _t: PhantomData, } } } impl Clone for Timeout { fn clone(&self) -> Self { Timeout::new(self.timeout) } } impl Transform for Timeout where S: Service, { type Response = S::Response; type Error = TimeoutError; type InitError = E; type Transform = TimeoutService; type Future = TimeoutFuture; fn new_transform(&self, service: S) -> Self::Future { let service = TimeoutService { service, timeout: self.timeout, _phantom: PhantomData, }; TimeoutFuture { service: Some(service), _err: PhantomData, } } } pub struct TimeoutFuture { service: Option, _err: PhantomData, } impl Unpin for TimeoutFuture {} impl Future for TimeoutFuture { type Output = Result; fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { Poll::Ready(Ok(self.get_mut().service.take().unwrap())) } } /// Applies a timeout to requests. #[derive(Debug, Clone)] pub struct TimeoutService { service: S, timeout: time::Duration, _phantom: PhantomData, } impl TimeoutService where S: Service, { pub fn new(timeout: time::Duration, service: U) -> Self where U: IntoService, { TimeoutService { timeout, service: service.into_service(), _phantom: PhantomData, } } } impl Service for TimeoutService where S: Service, { type Response = S::Response; type Error = TimeoutError; type Future = TimeoutServiceResponse; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx).map_err(TimeoutError::Service) } fn call(&mut self, request: Req) -> Self::Future { TimeoutServiceResponse { fut: self.service.call(request), sleep: delay_for(self.timeout), } } } pin_project! { /// `TimeoutService` response future #[derive(Debug)] pub struct TimeoutServiceResponse where S: Service { #[pin] fut: S::Future, sleep: Delay, } } impl Future for TimeoutServiceResponse where S: Service, { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); // First, try polling the future if let Poll::Ready(res) = this.fut.poll(cx) { return match res { Ok(v) => Poll::Ready(Ok(v)), Err(e) => Poll::Ready(Err(TimeoutError::Service(e))), }; } // Now check the sleep Pin::new(this.sleep) .poll(cx) .map(|_| Err(TimeoutError::Timeout)) } } #[cfg(test)] mod tests { use std::task::Poll; use std::time::Duration; use super::*; use actix_service::{apply, fn_factory, Service, ServiceFactory}; use futures_util::future::{ok, FutureExt, LocalBoxFuture}; struct SleepService(Duration); impl Service<()> for SleepService { type Response = (); type Error = (); type Future = LocalBoxFuture<'static, Result<(), ()>>; actix_service::always_ready!(); fn call(&mut self, _: ()) -> Self::Future { actix_rt::time::delay_for(self.0) .then(|_| ok::<_, ()>(())) .boxed_local() } } #[actix_rt::test] async fn test_success() { let resolution = Duration::from_millis(100); let wait_time = Duration::from_millis(50); let mut timeout = TimeoutService::new(resolution, SleepService(wait_time)); assert_eq!(timeout.call(()).await, Ok(())); } #[actix_rt::test] async fn test_timeout() { let resolution = Duration::from_millis(100); let wait_time = Duration::from_millis(500); let mut timeout = TimeoutService::new(resolution, SleepService(wait_time)); assert_eq!(timeout.call(()).await, Err(TimeoutError::Timeout)); } #[actix_rt::test] async fn test_timeout_new_service() { let resolution = Duration::from_millis(100); let wait_time = Duration::from_millis(500); let timeout = apply( Timeout::new(resolution), fn_factory(|| ok::<_, ()>(SleepService(wait_time))), ); let mut srv = timeout.new_service(&()).await.unwrap(); assert_eq!(srv.call(()).await, Err(TimeoutError::Timeout)); } }