use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use pin_project::pin_project; use super::{Service, ServiceFactory}; /// Service for the `map` combinator, changing the type of a service's response. /// /// This is created by the `ServiceExt::map` method. pub(crate) struct Map { service: A, f: F, _t: PhantomData, } impl Map { /// Create new `Map` combinator pub fn new(service: A, f: F) -> Self where A: Service, F: FnMut(A::Response) -> Response, { Self { service, f, _t: PhantomData, } } } impl Clone for Map where A: Clone, F: Clone, { fn clone(&self) -> Self { Map { service: self.service.clone(), f: self.f.clone(), _t: PhantomData, } } } impl Service for Map where A: Service, F: FnMut(A::Response) -> Response + Clone, { type Request = A::Request; type Response = Response; type Error = A::Error; type Future = MapFuture; fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { self.service.poll_ready(ctx) } fn call(&mut self, req: A::Request) -> Self::Future { MapFuture::new(self.service.call(req), self.f.clone()) } } #[pin_project] pub(crate) struct MapFuture where A: Service, F: FnMut(A::Response) -> Response, { f: F, #[pin] fut: A::Future, } impl MapFuture where A: Service, F: FnMut(A::Response) -> Response, { fn new(fut: A::Future, f: F) -> Self { MapFuture { f, fut } } } impl Future for MapFuture where A: Service, F: FnMut(A::Response) -> Response, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); match this.fut.poll(cx) { Poll::Ready(Ok(resp)) => Poll::Ready(Ok((this.f)(resp))), Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } } } /// `MapNewService` new service combinator pub(crate) struct MapNewService { a: A, f: F, r: PhantomData, } impl MapNewService { /// Create new `Map` new service instance pub fn new(a: A, f: F) -> Self where A: ServiceFactory, F: FnMut(A::Response) -> Res, { Self { a, f, r: PhantomData, } } } impl Clone for MapNewService where A: Clone, F: Clone, { fn clone(&self) -> Self { Self { a: self.a.clone(), f: self.f.clone(), r: PhantomData, } } } impl ServiceFactory for MapNewService where A: ServiceFactory, F: FnMut(A::Response) -> Res + Clone, { type Request = A::Request; type Response = Res; type Error = A::Error; type Config = A::Config; type Service = Map; type InitError = A::InitError; type Future = MapNewServiceFuture; fn new_service(&self, cfg: &A::Config) -> Self::Future { MapNewServiceFuture::new(self.a.new_service(cfg), self.f.clone()) } } #[pin_project] pub(crate) struct MapNewServiceFuture where A: ServiceFactory, F: FnMut(A::Response) -> Res, { #[pin] fut: A::Future, f: Option, } impl MapNewServiceFuture where A: ServiceFactory, F: FnMut(A::Response) -> Res, { fn new(fut: A::Future, f: F) -> Self { MapNewServiceFuture { f: Some(f), fut } } } impl Future for MapNewServiceFuture where A: ServiceFactory, F: FnMut(A::Response) -> Res, { type Output = Result, A::InitError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if let Poll::Ready(svc) = this.fut.poll(cx)? { Poll::Ready(Ok(Map::new(svc, this.f.take().unwrap()))) } else { Poll::Pending } } } #[cfg(test)] mod tests { use futures::future::{lazy, ok, Ready}; use super::*; use crate::{into_factory, into_service, Service}; struct Srv; impl Service for Srv { type Request = (); type Response = (); type Error = (); type Future = Ready>; fn poll_ready(&mut self, _: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _: ()) -> Self::Future { ok(()) } } #[tokio::test] async fn test_poll_ready() { let mut srv = into_service(Srv).map(|_| "ok"); let res = lazy(|cx| srv.poll_ready(cx)).await; assert_eq!(res, Poll::Ready(Ok(()))); } #[tokio::test] async fn test_call() { let mut srv = into_service(Srv).map(|_| "ok"); let res = srv.call(()).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), "ok"); } #[tokio::test] async fn test_new_service() { let new_srv = into_factory(|| ok::<_, ()>(Srv)).map(|_| "ok"); let mut srv = new_srv.new_service(&()).await.unwrap(); let res = srv.call(()).await; assert!(res.is_ok()); assert_eq!(res.unwrap(), ("ok")); } }