//! Framed dispatcher service and related utilities #![allow(type_alias_bounds)] use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, mem}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_service::{IntoService, Service}; use futures_util::{future::Future, stream::Stream, FutureExt}; use log::debug; use crate::mpsc; /// Framed transport errors pub enum DispatcherError + Decoder, I> { Service(E), Encoder(>::Error), Decoder(::Error), } impl + Decoder, I> From for DispatcherError { fn from(err: E) -> Self { DispatcherError::Service(err) } } impl + Decoder, I> fmt::Debug for DispatcherError where E: fmt::Debug, >::Error: fmt::Debug, ::Error: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { DispatcherError::Service(ref e) => write!(fmt, "DispatcherError::Service({:?})", e), DispatcherError::Encoder(ref e) => write!(fmt, "DispatcherError::Encoder({:?})", e), DispatcherError::Decoder(ref e) => write!(fmt, "DispatcherError::Decoder({:?})", e), } } } impl + Decoder, I> fmt::Display for DispatcherError where E: fmt::Display, >::Error: fmt::Debug, ::Error: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { DispatcherError::Service(ref e) => write!(fmt, "{}", e), DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e), DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e), } } } pub enum Message { Item(T), Close, } /// Dispatcher is a future that reads frames from Framed object /// and passes them to the service. #[pin_project::pin_project] pub struct Dispatcher where S: Service::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Encoder + Decoder, I: 'static, >::Error: std::fmt::Debug, { service: S, state: State, #[pin] framed: Framed, rx: mpsc::Receiver, S::Error>>, tx: mpsc::Sender, S::Error>>, } enum State + Decoder, I> { Processing, Error(DispatcherError), FramedError(DispatcherError), FlushAndStop, Stopping, } impl + Decoder, I> State { fn take_error(&mut self) -> DispatcherError { match mem::replace(self, State::Processing) { State::Error(err) => err, _ => panic!(), } } fn take_framed_error(&mut self) -> DispatcherError { match mem::replace(self, State::Processing) { State::FramedError(err) => err, _ => panic!(), } } } impl Dispatcher where S: Service::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, ::Error: std::fmt::Debug, >::Error: std::fmt::Debug, { pub fn new>(framed: Framed, service: F) -> Self { let (tx, rx) = mpsc::channel(); Dispatcher { framed, rx, tx, service: service.into_service(), state: State::Processing, } } /// Construct new `Dispatcher` instance with customer `mpsc::Receiver` pub fn with_rx>( framed: Framed, service: F, rx: mpsc::Receiver, S::Error>>, ) -> Self { let tx = rx.sender(); Dispatcher { framed, rx, tx, service: service.into_service(), state: State::Processing, } } /// Get sink pub fn get_sink(&self) -> mpsc::Sender, S::Error>> { self.tx.clone() } /// Get reference to a service wrapped by `Dispatcher` instance. pub fn get_ref(&self) -> &S { &self.service } /// Get mutable reference to a service wrapped by `Dispatcher` instance. pub fn get_mut(&mut self) -> &mut S { &mut self.service } /// Get reference to a framed instance wrapped by `Dispatcher` /// instance. pub fn get_framed(&self) -> &Framed { &self.framed } /// Get mutable reference to a framed instance wrapped by `Dispatcher` instance. pub fn get_framed_mut(&mut self) -> &mut Framed { &mut self.framed } fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool where S: Service::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: std::fmt::Debug, { loop { let this = self.as_mut().project(); match this.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { let item = match this.framed.next_item(cx) { Poll::Ready(Some(Ok(el))) => el, Poll::Ready(Some(Err(err))) => { *this.state = State::FramedError(DispatcherError::Decoder(err)); return true; } Poll::Pending => return false, Poll::Ready(None) => { *this.state = State::Stopping; return true; } }; let tx = this.tx.clone(); actix_rt::spawn(this.service.call(item).map(move |item| { let _ = tx.send(item.map(Message::Item)); })); } Poll::Pending => return false, Poll::Ready(Err(err)) => { *this.state = State::Error(DispatcherError::Service(err)); return true; } } } } /// write to framed object fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool where S: Service::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: std::fmt::Debug, { loop { let mut this = self.as_mut().project(); while !this.framed.is_write_buf_full() { match Pin::new(&mut this.rx).poll_next(cx) { Poll::Ready(Some(Ok(Message::Item(msg)))) => { if let Err(err) = this.framed.as_mut().write(msg) { *this.state = State::FramedError(DispatcherError::Encoder(err)); return true; } } Poll::Ready(Some(Ok(Message::Close))) => { *this.state = State::FlushAndStop; return true; } Poll::Ready(Some(Err(err))) => { *this.state = State::Error(DispatcherError::Service(err)); return true; } Poll::Ready(None) | Poll::Pending => break, } } if !this.framed.is_write_buf_empty() { match this.framed.flush(cx) { Poll::Pending => break, Poll::Ready(Ok(_)) => (), Poll::Ready(Err(err)) => { debug!("Error sending data: {:?}", err); *this.state = State::FramedError(DispatcherError::Encoder(err)); return true; } } } else { break; } } false } } impl Future for Dispatcher where S: Service::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: std::fmt::Debug, ::Error: std::fmt::Debug, { type Output = Result<(), DispatcherError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let this = self.as_mut().project(); return match this.state { State::Processing => { if self.as_mut().poll_read(cx) || self.as_mut().poll_write(cx) { continue; } else { Poll::Pending } } State::Error(_) => { // flush write buffer if !this.framed.is_write_buf_empty() { if let Poll::Pending = this.framed.flush(cx) { return Poll::Pending; } } Poll::Ready(Err(this.state.take_error())) } State::FlushAndStop => { if !this.framed.is_write_buf_empty() { match this.framed.flush(cx) { Poll::Ready(Err(err)) => { debug!("Error sending data: {:?}", err); Poll::Ready(Ok(())) } Poll::Pending => Poll::Pending, Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), } } else { Poll::Ready(Ok(())) } } State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())), State::Stopping => Poll::Ready(Ok(())), }; } } }