use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use super::{IntoService, IntoServiceFactory, Service, ServiceFactory}; /// Apply tranform function to a service pub fn apply_fn(service: U, f: F) -> Apply where T: Service, F: FnMut(In, &mut T) -> R, R: Future>, U: IntoService, { Apply::new(service.into_service(), f) } /// Create factory for `apply` service. pub fn apply_fn_factory( service: U, f: F, ) -> ApplyServiceFactory where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, U: IntoServiceFactory, { ApplyServiceFactory::new(service.into_factory(), f) } /// `Apply` service combinator pub struct Apply where T: Service, { service: T, f: F, r: PhantomData<(In, Out, R)>, } impl Apply where T: Service, F: FnMut(In, &mut T) -> R, R: Future>, { /// Create new `Apply` combinator fn new(service: T, f: F) -> Self { Self { service, f, r: PhantomData, } } } impl Service for Apply where T: Service, F: FnMut(In, &mut T) -> R, R: Future>, { type Request = In; type Response = Out; type Error = Err; type Future = R; fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { Poll::Ready(futures::ready!(self.service.poll_ready(ctx))) } fn call(&mut self, req: In) -> Self::Future { (self.f)(req, &mut self.service) } } /// `apply()` service factory pub struct ApplyServiceFactory where T: ServiceFactory, { service: T, f: F, r: PhantomData<(R, In, Out)>, } impl ApplyServiceFactory where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, { /// Create new `ApplyNewService` new service instance fn new(service: T, f: F) -> Self { Self { f, service, r: PhantomData, } } } impl ServiceFactory for ApplyServiceFactory where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, { type Request = In; type Response = Out; type Error = Err; type Config = T::Config; type Service = Apply; type InitError = T::InitError; type Future = ApplyServiceFactoryResponse; fn new_service(&self, cfg: &T::Config) -> Self::Future { ApplyServiceFactoryResponse::new(self.service.new_service(cfg), self.f.clone()) } } #[pin_project::pin_project] pub struct ApplyServiceFactoryResponse where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, { #[pin] fut: T::Future, f: Option, r: PhantomData<(In, Out)>, } impl ApplyServiceFactoryResponse where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, { fn new(fut: T::Future, f: F) -> Self { Self { f: Some(f), fut, r: PhantomData, } } } impl Future for ApplyServiceFactoryResponse where T: ServiceFactory, F: FnMut(In, &mut T::Service) -> R + Clone, R: Future>, { type Output = Result, T::InitError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if let Poll::Ready(svc) = this.fut.poll(cx)? { Poll::Ready(Ok(Apply::new(svc, this.f.take().unwrap()))) } else { Poll::Pending } } } #[cfg(test)] mod tests { use std::task::{Context, Poll}; use futures::future::{lazy, ok, Ready}; use super::*; use crate::{pipeline, pipeline_factory, Service, ServiceFactory}; #[derive(Clone)] struct Srv; impl Service for Srv { type Request = (); type Response = (); type Error = (); type Future = Ready>; fn poll_ready(&mut self, _: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _: ()) -> Self::Future { ok(()) } } #[tokio::test] async fn test_call() { let mut srv = pipeline(apply_fn(Srv, |req: &'static str, srv| { let fut = srv.call(()); async move { let res = fut.await.unwrap(); Ok((req, res)) } })); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let res = srv.call("srv").await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv", ()))); } #[tokio::test] async fn test_new_service() { let new_srv = pipeline_factory(apply_fn_factory( || ok::<_, ()>(Srv), |req: &'static str, srv| { let fut = srv.call(()); async move { let res = fut.await.unwrap(); Ok((req, res)) } }, )); let mut srv = new_srv.new_service(&()).await.unwrap(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let res = srv.call("srv").await; assert!(res.is_ok()); assert_eq!(res.unwrap(), (("srv", ()))); } }