use std::convert::Infallible; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use actix_service::{IntoService, Service, Transform}; use futures::future::{ok, Ready}; use super::counter::{Counter, CounterGuard}; /// InFlight - new service for service that can limit number of in-flight /// async requests. /// /// Default number of in-flight requests is 15 pub struct InFlight { max_inflight: usize, } impl InFlight { pub fn new(max: usize) -> Self { Self { max_inflight: max } } } impl Default for InFlight { fn default() -> Self { Self::new(15) } } impl Transform for InFlight where S: Service, S::Future: Unpin, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = Infallible; type Transform = InFlightService; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(InFlightService::new(self.max_inflight, service)) } } pub struct InFlightService { count: Counter, service: S, } impl InFlightService where S: Service, { pub fn new(max: usize, service: U) -> Self where U: IntoService, { Self { count: Counter::new(max), service: service.into_service(), } } } impl Service for InFlightService where T: Service, T::Future: Unpin, { type Request = T::Request; type Response = T::Response; type Error = T::Error; type Future = InFlightServiceResponse; fn poll_ready(&mut self, cx: &mut Context) -> Poll> { if let Poll::Pending = self.service.poll_ready(cx)? { Poll::Pending } else if !self.count.available(cx) { log::trace!("InFlight limit exceeded"); Poll::Pending } else { Poll::Ready(Ok(())) } } fn call(&mut self, req: T::Request) -> Self::Future { InFlightServiceResponse { fut: self.service.call(req), _guard: self.count.get(), } } } #[doc(hidden)] pub struct InFlightServiceResponse { fut: T::Future, _guard: CounterGuard, } impl Future for InFlightServiceResponse where T::Future: Unpin, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { Pin::new(&mut self.get_mut().fut).poll(cx) } } #[cfg(test)] mod tests { use std::task::{Context, Poll}; use std::time::Duration; use super::*; use actix_service::{apply, factory_fn, Service, ServiceFactory}; use futures::future::{lazy, ok, FutureExt, LocalBoxFuture}; struct SleepService(Duration); impl Service for SleepService { type Request = (); type Response = (); type Error = (); type Future = LocalBoxFuture<'static, Result<(), ()>>; fn poll_ready(&mut self, _: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _: ()) -> Self::Future { tokio_timer::delay_for(self.0) .then(|_| ok::<_, ()>(())) .boxed_local() } } #[test] fn test_transform() { let wait_time = Duration::from_millis(50); let _ = actix_rt::System::new("test").block_on(async { let mut srv = InFlightService::new(1, SleepService(wait_time)); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let res = srv.call(()); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending); let _ = res.await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); }); } #[test] fn test_newtransform() { let wait_time = Duration::from_millis(50); actix_rt::System::new("test").block_on(async { let srv = apply(InFlight::new(1), factory_fn(|| ok(SleepService(wait_time)))); let mut srv = srv.new_service(&()).await.unwrap(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let res = srv.call(()); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending); let _ = res.await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); }); } }