diff --git a/actix-utils/src/cell.rs b/actix-utils/src/cell.rs index 716982ae..ee35125e 100644 --- a/actix-utils/src/cell.rs +++ b/actix-utils/src/cell.rs @@ -41,6 +41,7 @@ impl Cell { unsafe { &mut *self.inner.as_ref().get() } } + #[allow(clippy::mut_from_ref)] pub(crate) unsafe fn get_mut_unsafe(&self) -> &mut T { &mut *self.inner.as_ref().get() } diff --git a/actix-utils/src/framed.rs b/actix-utils/src/framed.rs index ce2ac148..c924cbe1 100644 --- a/actix-utils/src/framed.rs +++ b/actix-utils/src/framed.rs @@ -1,37 +1,33 @@ //! Framed dispatcher service and related utilities #![allow(type_alias_bounds)] -use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, mem}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_service::{IntoService, Service}; -use futures::future::{ready, FutureExt}; -use futures::{Future, Sink, Stream}; +use futures::{Future, FutureExt, Stream}; use log::debug; -use crate::cell::Cell; use crate::mpsc; -use crate::task::LocalWaker; type Request = ::Item; type Response = ::Item; /// Framed transport errors -pub enum FramedTransportError { +pub enum DispatcherError { Service(E), Encoder(::Error), Decoder(::Error), } -impl From for FramedTransportError { +impl From for DispatcherError { fn from(err: E) -> Self { - FramedTransportError::Service(err) + DispatcherError::Service(err) } } -impl fmt::Debug for FramedTransportError +impl fmt::Debug for DispatcherError where E: fmt::Debug, ::Error: fmt::Debug, @@ -39,20 +35,14 @@ where { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - FramedTransportError::Service(ref e) => { - write!(fmt, "FramedTransportError::Service({:?})", e) - } - FramedTransportError::Encoder(ref e) => { - write!(fmt, "FramedTransportError::Encoder({:?})", e) - } - FramedTransportError::Decoder(ref e) => { - write!(fmt, "FramedTransportError::Encoder({:?})", e) - } + DispatcherError::Service(ref e) => write!(fmt, "DispatcherError::Service({:?})", e), + DispatcherError::Encoder(ref e) => write!(fmt, "DispatcherError::Encoder({:?})", e), + DispatcherError::Decoder(ref e) => write!(fmt, "DispatcherError::Decoder({:?})", e), } } } -impl fmt::Display for FramedTransportError +impl fmt::Display for DispatcherError where E: fmt::Display, ::Error: fmt::Debug, @@ -60,25 +50,22 @@ where { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - FramedTransportError::Service(ref e) => write!(fmt, "{}", e), - FramedTransportError::Encoder(ref e) => write!(fmt, "{:?}", e), - FramedTransportError::Decoder(ref e) => write!(fmt, "{:?}", e), + DispatcherError::Service(ref e) => write!(fmt, "{}", e), + DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e), + DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e), } } } -pub enum FramedMessage { - Message(T), +pub enum Message { + Item(T), Close, } -type Rx = Option::Item>>>; -type Inner = Cell::Item, S::Error>>; - /// FramedTransport - is a future that reads frames from Framed object /// and pass then to the service. #[pin_project::pin_project] -pub struct FramedTransport +pub struct Dispatcher where S: Service, Response = Response>, S::Error: 'static, @@ -89,26 +76,37 @@ where ::Error: std::fmt::Debug, { service: S, - state: TransportState, + state: State, framed: Framed, - rx: Option::Item>>>, - inner: Cell::Item, S::Error>>, + rx: mpsc::Receiver::Item>, S::Error>>, + tx: mpsc::Sender::Item>, S::Error>>, } -enum TransportState { +enum State { Processing, - Error(FramedTransportError), - FramedError(FramedTransportError), + Error(DispatcherError), + FramedError(DispatcherError), FlushAndStop, Stopping, } -struct FramedTransportInner { - buf: VecDeque>, - task: LocalWaker, +impl State { + fn take_error(&mut self) -> DispatcherError { + match mem::replace(self, State::Processing) { + State::Error(err) => err, + _ => panic!(), + } + } + + fn take_framed_error(&mut self) -> DispatcherError { + match mem::replace(self, State::Processing) { + State::FramedError(err) => err, + _ => panic!(), + } + } } -impl FramedTransport +impl Dispatcher where S: Service, Response = Response>, S::Error: 'static, @@ -119,25 +117,19 @@ where ::Error: std::fmt::Debug, { pub fn new>(framed: Framed, service: F) -> Self { - FramedTransport { + let (tx, rx) = mpsc::channel(); + Dispatcher { framed, - rx: None, + rx, + tx, service: service.into_service(), - state: TransportState::Processing, - inner: Cell::new(FramedTransportInner { - buf: VecDeque::new(), - task: LocalWaker::new(), - }), + state: State::Processing, } } - /// Get Sender - pub fn set_receiver( - mut self, - rx: mpsc::Receiver::Item>>, - ) -> Self { - self.rx = Some(rx); - self + /// Get sink + pub fn get_sink(&self) -> mpsc::Sender::Item>, S::Error>> { + self.tx.clone() } /// Get reference to a service wrapped by `FramedTransport` instance. @@ -162,9 +154,99 @@ where pub fn get_framed_mut(&mut self) -> &mut Framed { &mut self.framed } + + fn poll_read(&mut self, cx: &mut Context<'_>) -> bool + where + S: Service, Response = Response>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + ::Item: 'static, + ::Error: std::fmt::Debug, + { + loop { + match self.service.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + let item = match self.framed.next_item(cx) { + Poll::Ready(Some(Ok(el))) => el, + Poll::Ready(Some(Err(err))) => { + self.state = State::FramedError(DispatcherError::Decoder(err)); + return true; + } + Poll::Pending => return false, + Poll::Ready(None) => { + self.state = State::Stopping; + return true; + } + }; + + let tx = self.tx.clone(); + actix_rt::spawn(self.service.call(item).map(move |item| { + let _ = tx.send(item.map(Message::Item)); + })); + } + Poll::Pending => return false, + Poll::Ready(Err(err)) => { + self.state = State::Error(DispatcherError::Service(err)); + return true; + } + } + } + } + + /// write to framed object + fn poll_write(&mut self, cx: &mut Context<'_>) -> bool + where + S: Service, Response = Response>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + ::Item: 'static, + ::Error: std::fmt::Debug, + { + loop { + while !self.framed.is_write_buf_full() { + match Pin::new(&mut self.rx).poll_next(cx) { + Poll::Ready(Some(Ok(Message::Item(msg)))) => { + if let Err(err) = self.framed.write(msg) { + self.state = State::FramedError(DispatcherError::Encoder(err)); + return true; + } + } + Poll::Ready(Some(Ok(Message::Close))) => { + self.state = State::FlushAndStop; + return true; + } + Poll::Ready(Some(Err(err))) => { + self.state = State::Error(DispatcherError::Service(err)); + return true; + } + Poll::Ready(None) | Poll::Pending => break, + } + } + + if !self.framed.is_write_buf_empty() { + match self.framed.flush(cx) { + Poll::Pending => break, + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => { + debug!("Error sending data: {:?}", err); + self.state = State::FramedError(DispatcherError::Encoder(err)); + return true; + } + } + } else { + break; + } + } + + false + } } -impl Future for FramedTransport +impl Future for Dispatcher where S: Service, Response = Response>, S::Error: 'static, @@ -175,210 +257,50 @@ where ::Error: std::fmt::Debug, ::Error: std::fmt::Debug, { - type Output = Result<(), FramedTransportError>; + type Output = Result<(), DispatcherError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.inner.get_ref().task.register(cx.waker()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let this = self.as_mut().project(); - let this = self.project(); - poll( - cx, - this.service, - this.state, - this.framed, - this.rx, - this.inner, - ) - } -} - -fn poll( - cx: &mut Context<'_>, - srv: &mut S, - state: &mut TransportState, - framed: &mut Framed, - rx: &mut Rx, - inner: &mut Inner, -) -> Poll>> -where - S: Service, Response = Response>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - match mem::replace(state, TransportState::Processing) { - TransportState::Processing => { - if poll_read(cx, srv, state, framed, inner) - || poll_write(cx, state, framed, rx, inner) - { - poll(cx, srv, state, framed, rx, inner) - } else { - Poll::Pending - } - } - TransportState::Error(err) => { - let is_empty = framed.is_write_buf_empty(); - if is_empty || poll_write(cx, state, framed, rx, inner) { - Poll::Ready(Err(err)) - } else { - *state = TransportState::Error(err); - Poll::Pending - } - } - TransportState::FlushAndStop => { - if !framed.is_write_buf_empty() { - match Pin::new(framed).poll_flush(cx) { - Poll::Ready(Err(err)) => { - debug!("Error sending data: {:?}", err); + return match this.state { + State::Processing => { + if self.poll_read(cx) || self.poll_write(cx) { + continue; + } else { + Poll::Pending + } + } + State::Error(_) => { + // flush write buffer + if !self.framed.is_write_buf_empty() { + match self.framed.flush(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => { + Poll::Ready(Err(self.state.take_error())) + } + } + } else { + Poll::Ready(Err(self.state.take_error())) + } + } + State::FlushAndStop => { + if !this.framed.is_write_buf_empty() { + match this.framed.flush(cx) { + Poll::Ready(Err(err)) => { + debug!("Error sending data: {:?}", err); + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + } + } else { Poll::Ready(Ok(())) } - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), } - } else { - Poll::Ready(Ok(())) - } - } - TransportState::FramedError(err) => Poll::Ready(Err(err)), - TransportState::Stopping => Poll::Ready(Ok(())), - } -} - -fn poll_read( - cx: &mut Context<'_>, - srv: &mut S, - state: &mut TransportState, - framed: &mut Framed, - inner: &mut Inner, -) -> bool -where - S: Service, Response = Response>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - loop { - match srv.poll_ready(cx) { - Poll::Ready(Ok(_)) => { - let item = match framed.next_item(cx) { - Poll::Ready(Some(Ok(el))) => el, - Poll::Ready(Some(Err(err))) => { - *state = - TransportState::FramedError(FramedTransportError::Decoder(err)); - return true; - } - Poll::Pending => return false, - Poll::Ready(None) => { - *state = TransportState::Stopping; - return true; - } - }; - - let mut cell = inner.clone(); - let fut = srv.call(item).then(move |item| { - let inner = cell.get_mut(); - inner.buf.push_back(item); - inner.task.wake(); - ready(()) - }); - actix_rt::spawn(fut); - } - Poll::Pending => return false, - Poll::Ready(Err(err)) => { - *state = TransportState::Error(FramedTransportError::Service(err)); - return true; - } + State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())), + State::Stopping => Poll::Ready(Ok(())), + }; } } } - -/// write to framed object -fn poll_write( - cx: &mut Context<'_>, - state: &mut TransportState, - framed: &mut Framed, - rx: &mut Rx, - inner: &mut Inner, -) -> bool -where - S: Service, Response = Response>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - // let this = self.project(); - - let inner = inner.get_mut(); - let mut rx_done = rx.is_none(); - let mut buf_empty = inner.buf.is_empty(); - loop { - while !framed.is_write_buf_full() { - if !buf_empty { - match inner.buf.pop_front().unwrap() { - Ok(msg) => { - if let Err(err) = framed.write(msg) { - *state = - TransportState::FramedError(FramedTransportError::Encoder(err)); - return true; - } - buf_empty = inner.buf.is_empty(); - } - Err(err) => { - *state = TransportState::Error(FramedTransportError::Service(err)); - return true; - } - } - } - - if !rx_done && rx.is_some() { - match Pin::new(rx.as_mut().unwrap()).poll_next(cx) { - Poll::Ready(Some(FramedMessage::Message(msg))) => { - if let Err(err) = framed.write(msg) { - *state = - TransportState::FramedError(FramedTransportError::Encoder(err)); - return true; - } - } - Poll::Ready(Some(FramedMessage::Close)) => { - *state = TransportState::FlushAndStop; - return true; - } - Poll::Ready(None) => { - rx_done = true; - let _ = rx.take(); - } - Poll::Pending => rx_done = true, - } - } - - if rx_done && buf_empty { - break; - } - } - - if !framed.is_write_buf_empty() { - match framed.flush(cx) { - Poll::Pending => break, - Poll::Ready(Err(err)) => { - debug!("Error sending data: {:?}", err); - *state = TransportState::FramedError(FramedTransportError::Encoder(err)); - return true; - } - Poll::Ready(Ok(_)) => (), - } - } else { - break; - } - } - - false -} diff --git a/actix-utils/src/stream.rs b/actix-utils/src/stream.rs index 866bbf68..b7f008e4 100644 --- a/actix-utils/src/stream.rs +++ b/actix-utils/src/stream.rs @@ -3,12 +3,12 @@ use std::pin::Pin; use std::task::{Context, Poll}; use actix_service::{IntoService, Service}; -use futures::Stream; +use futures::{FutureExt, Stream}; use crate::mpsc; #[pin_project::pin_project] -pub struct StreamDispatcher +pub struct Dispatcher where S: Stream, T: Service + 'static, @@ -20,7 +20,7 @@ where err_tx: mpsc::Sender, } -impl StreamDispatcher +impl Dispatcher where S: Stream, T: Service + 'static, @@ -30,7 +30,7 @@ where F: IntoService, { let (err_tx, err_rx) = mpsc::channel(); - StreamDispatcher { + Dispatcher { err_rx, err_tx, stream, @@ -39,7 +39,7 @@ where } } -impl Future for StreamDispatcher +impl Future for Dispatcher where S: Stream, T: Service + 'static, @@ -54,47 +54,23 @@ where } loop { - match this.service.poll_ready(cx)? { + return match this.service.poll_ready(cx)? { Poll::Ready(_) => match this.stream.poll_next(cx) { Poll::Ready(Some(item)) => { - actix_rt::spawn(StreamDispatcherService { - fut: this.service.call(item), - stop: self.err_tx.clone(), - }); + let stop = this.err_tx.clone(); + actix_rt::spawn(this.service.call(item).map(move |res| { + if let Err(e) = res { + let _ = stop.send(e); + } + })); this = self.as_mut().project(); + continue; } - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(Ok(())), }, - Poll::Pending => return Poll::Pending, - } - } - } -} - -#[pin_project::pin_project] -struct StreamDispatcherService { - #[pin] - fut: F, - stop: mpsc::Sender, -} - -impl Future for StreamDispatcherService -where - F: Future>, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - match this.fut.poll(cx) { - Poll::Ready(Ok(_)) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let _ = this.stop.send(e); - Poll::Ready(()) - } + Poll::Pending => Poll::Pending, + }; } } }