use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use super::{Service, ServiceFactory}; use crate::cell::Cell; /// Service for the `then` combinator, chaining a computation onto the end of /// another service. /// /// This is created by the `ServiceExt::then` method. pub struct ThenService { a: A, b: Cell, } impl ThenService { /// Create new `.then()` combinator pub(crate) fn new(a: A, b: B) -> ThenService where A: Service, B: Service, Error = A::Error>, { Self { a, b: Cell::new(b) } } } impl Clone for ThenService where A: Clone, { fn clone(&self) -> Self { ThenService { a: self.a.clone(), b: self.b.clone(), } } } impl Service for ThenService where A: Service, B: Service, Error = A::Error>, { type Request = A::Request; type Response = B::Response; type Error = B::Error; type Future = ThenServiceResponse; fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { let not_ready = !self.a.poll_ready(ctx)?.is_ready(); if !self.b.get_mut().poll_ready(ctx)?.is_ready() || not_ready { Poll::Pending } else { Poll::Ready(Ok(())) } } fn call(&mut self, req: A::Request) -> Self::Future { ThenServiceResponse::new(self.a.call(req), self.b.clone()) } } #[pin_project::pin_project] pub struct ThenServiceResponse where A: Service, B: Service>, { b: Cell, #[pin] fut_b: Option, #[pin] fut_a: Option, } impl ThenServiceResponse where A: Service, B: Service>, { fn new(a: A::Future, b: Cell) -> Self { ThenServiceResponse { b, fut_a: Some(a), fut_b: None, } } } impl Future for ThenServiceResponse where A: Service, B: Service>, { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.as_mut().project(); loop { if let Some(fut) = this.fut_b.as_pin_mut() { return fut.poll(cx); } match this .fut_a .as_pin_mut() .expect("Bug in actix-service") .poll(cx) { Poll::Ready(r) => { this = self.as_mut().project(); this.fut_b.set(Some(this.b.get_mut().call(r))); } Poll::Pending => return Poll::Pending, } } } } /// `.then()` service factory combinator pub struct ThenServiceFactory { a: A, b: B, } impl ThenServiceFactory where A: ServiceFactory, A::Config: Clone, B: ServiceFactory< Config = A::Config, Request = Result, Error = A::Error, InitError = A::InitError, >, { /// Create new `AndThen` combinator pub(crate) fn new(a: A, b: B) -> Self { Self { a, b } } } impl ServiceFactory for ThenServiceFactory where A: ServiceFactory, A::Config: Clone, B: ServiceFactory< Config = A::Config, Request = Result, Error = A::Error, InitError = A::InitError, >, { type Request = A::Request; type Response = B::Response; type Error = A::Error; type Config = A::Config; type Service = ThenService; type InitError = A::InitError; type Future = ThenServiceFactoryResponse; fn new_service(&self, cfg: A::Config) -> Self::Future { ThenServiceFactoryResponse::new( self.a.new_service(cfg.clone()), self.b.new_service(cfg), ) } } impl Clone for ThenServiceFactory where A: Clone, B: Clone, { fn clone(&self) -> Self { Self { a: self.a.clone(), b: self.b.clone(), } } } #[pin_project::pin_project] pub struct ThenServiceFactoryResponse where A: ServiceFactory, B: ServiceFactory< Config = A::Config, Request = Result, Error = A::Error, InitError = A::InitError, >, { #[pin] fut_b: B::Future, #[pin] fut_a: A::Future, a: Option, b: Option, } impl ThenServiceFactoryResponse where A: ServiceFactory, B: ServiceFactory< Config = A::Config, Request = Result, Error = A::Error, InitError = A::InitError, >, { fn new(fut_a: A::Future, fut_b: B::Future) -> Self { Self { fut_a, fut_b, a: None, b: None, } } } impl Future for ThenServiceFactoryResponse where A: ServiceFactory, B: ServiceFactory< Config = A::Config, Request = Result, Error = A::Error, InitError = A::InitError, >, { type Output = Result, A::InitError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if this.a.is_none() { if let Poll::Ready(service) = this.fut_a.poll(cx)? { *this.a = Some(service); } } if this.b.is_none() { if let Poll::Ready(service) = this.fut_b.poll(cx)? { *this.b = Some(service); } } if this.a.is_some() && this.b.is_some() { Poll::Ready(Ok(ThenService::new( this.a.take().unwrap(), this.b.take().unwrap(), ))) } else { Poll::Pending } } } #[cfg(test)] mod tests { use std::cell::Cell; use std::rc::Rc; use std::task::{Context, Poll}; use futures::future::{err, lazy, ok, ready, Ready}; use crate::{pipeline, pipeline_factory, Service, ServiceFactory}; #[derive(Clone)] struct Srv1(Rc>); impl Service for Srv1 { type Request = Result<&'static str, &'static str>; type Response = &'static str; type Error = (); type Future = Ready>; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { self.0.set(self.0.get() + 1); Poll::Ready(Ok(())) } fn call(&mut self, req: Result<&'static str, &'static str>) -> Self::Future { match req { Ok(msg) => ok(msg), Err(_) => err(()), } } } struct Srv2(Rc>); impl Service for Srv2 { type Request = Result<&'static str, ()>; type Response = (&'static str, &'static str); type Error = (); type Future = Ready>; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { self.0.set(self.0.get() + 1); Poll::Ready(Err(())) } fn call(&mut self, req: Result<&'static str, ()>) -> Self::Future { match req { Ok(msg) => ok((msg, "ok")), Err(()) => ok(("srv2", "err")), } } } #[actix_rt::test] async fn test_poll_ready() { let cnt = Rc::new(Cell::new(0)); let mut srv = pipeline(Srv1(cnt.clone())).then(Srv2(cnt.clone())); let res = lazy(|cx| srv.poll_ready(cx)).await; assert_eq!(res, Poll::Ready(Err(()))); assert_eq!(cnt.get(), 2); } #[actix_rt::test] async fn test_call() { let cnt = Rc::new(Cell::new(0)); let mut srv = pipeline(Srv1(cnt.clone())).then(Srv2(cnt)); let res = srv.call(Ok("srv1")).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv1", "ok"))); let res = srv.call(Err("srv")).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv2", "err"))); } #[actix_rt::test] async fn test_factory() { let cnt = Rc::new(Cell::new(0)); let cnt2 = cnt.clone(); let blank = move || ready(Ok::<_, ()>(Srv1(cnt2.clone()))); let factory = pipeline_factory(blank).then(move || ready(Ok(Srv2(cnt.clone())))); let mut srv = factory.new_service(&()).await.unwrap(); let res = srv.call(Ok("srv1")).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv1", "ok"))); let res = srv.call(Err("srv")).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv2", "err"))); } }