diff --git a/src/lib.rs b/src/lib.rs index ab3f196d..a50debd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,6 +69,7 @@ pub mod server; pub mod service; pub mod ssl; pub mod stream; +pub mod timeout; pub mod timer; #[derive(Copy, Clone, Debug)] diff --git a/src/timeout.rs b/src/timeout.rs new file mode 100644 index 00000000..1262e6f0 --- /dev/null +++ b/src/timeout.rs @@ -0,0 +1,162 @@ +//! Tower middleware that applies a timeout to requests. +//! +//! If the response does not complete within the specified timeout, the response +//! will be aborted. +use std::fmt; +use std::time::Duration; + +use futures::{Async, Future, Poll}; +use tokio_timer::{clock, Delay}; + +use service::{NewService, Service}; + +/// Applies a timeout to requests. +#[derive(Debug)] +pub struct Timeout { + inner: T, + timeout: Duration, +} + +/// Timeout error +pub enum TimeoutError { + /// Service error + Service(E), + /// Service call timeout + Timeout, +} + +impl fmt::Debug for TimeoutError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TimeoutError::Service(e) => write!(f, "TimeoutError::Service({:?})", e), + TimeoutError::Timeout => write!(f, "TimeoutError::Timeout"), + } + } +} + +/// `Timeout` response future +#[derive(Debug)] +pub struct TimeoutFut { + fut: T::Future, + timeout: Duration, +} + +impl Timeout +where + T: NewService + Clone, +{ + pub fn new(timeout: Duration, inner: T) -> Self { + Timeout { inner, timeout } + } +} + +impl NewService for Timeout +where + T: NewService + Clone, +{ + type Request = T::Request; + type Response = T::Response; + type Error = TimeoutError; + type InitError = T::InitError; + type Service = TimeoutService; + type Future = TimeoutFut; + + fn new_service(&self) -> Self::Future { + TimeoutFut { + fut: self.inner.new_service(), + timeout: self.timeout.clone(), + } + } +} + +impl Future for TimeoutFut +where + T: NewService, +{ + type Item = TimeoutService; + type Error = T::InitError; + + fn poll(&mut self) -> Poll { + let service = try_ready!(self.fut.poll()); + Ok(Async::Ready(TimeoutService::new(self.timeout, service))) + } +} + +/// Applies a timeout to requests. +#[derive(Debug)] +pub struct TimeoutService { + inner: T, + timeout: Duration, +} + +impl TimeoutService { + pub fn new(timeout: Duration, inner: T) -> Self { + TimeoutService { inner, timeout } + } +} + +impl Clone for TimeoutService +where + T: Clone, +{ + fn clone(&self) -> Self { + TimeoutService { + inner: self.inner.clone(), + timeout: self.timeout, + } + } +} + +impl Service for TimeoutService +where + T: Service, +{ + type Request = T::Request; + type Response = T::Response; + type Error = TimeoutError; + type Future = TimeoutServiceResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.inner + .poll_ready() + .map_err(|e| TimeoutError::Service(e)) + } + + fn call(&mut self, request: Self::Request) -> Self::Future { + TimeoutServiceResponse { + fut: self.inner.call(request), + sleep: Delay::new(clock::now() + self.timeout), + } + } +} + +/// `TimeoutService` response future +#[derive(Debug)] +pub struct TimeoutServiceResponse { + fut: T::Future, + sleep: Delay, +} + +impl Future for TimeoutServiceResponse +where + T: Service, +{ + type Item = T::Response; + type Error = TimeoutError; + + fn poll(&mut self) -> Poll { + // First, try polling the future + match self.fut.poll() { + Ok(Async::Ready(v)) => return Ok(Async::Ready(v)), + Ok(Async::NotReady) => {} + Err(e) => return Err(TimeoutError::Service(e)), + } + + // Now check the sleep + match self.sleep.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(_)) => Err(TimeoutError::Timeout), + Err(_) => Err(TimeoutError::Timeout), + } + } +}