1
0
mirror of https://github.com/fafhrd91/actix-net synced 2025-01-31 12:42:09 +01:00

refactor framed and stream dispatchers

This commit is contained in:
Nikolay Kim 2019-12-11 12:42:07 +06:00
parent 2e5e69c9ba
commit 631cb86947
3 changed files with 192 additions and 293 deletions

View File

@ -41,6 +41,7 @@ impl<T> Cell<T> {
unsafe { &mut *self.inner.as_ref().get() } unsafe { &mut *self.inner.as_ref().get() }
} }
#[allow(clippy::mut_from_ref)]
pub(crate) unsafe fn get_mut_unsafe(&self) -> &mut T { pub(crate) unsafe fn get_mut_unsafe(&self) -> &mut T {
&mut *self.inner.as_ref().get() &mut *self.inner.as_ref().get()
} }

View File

@ -1,37 +1,33 @@
//! Framed dispatcher service and related utilities //! Framed dispatcher service and related utilities
#![allow(type_alias_bounds)] #![allow(type_alias_bounds)]
use std::collections::VecDeque;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, mem}; use std::{fmt, mem};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use futures::future::{ready, FutureExt}; use futures::{Future, FutureExt, Stream};
use futures::{Future, Sink, Stream};
use log::debug; use log::debug;
use crate::cell::Cell;
use crate::mpsc; use crate::mpsc;
use crate::task::LocalWaker;
type Request<U> = <U as Decoder>::Item; type Request<U> = <U as Decoder>::Item;
type Response<U> = <U as Encoder>::Item; type Response<U> = <U as Encoder>::Item;
/// Framed transport errors /// Framed transport errors
pub enum FramedTransportError<E, U: Encoder + Decoder> { pub enum DispatcherError<E, U: Encoder + Decoder> {
Service(E), Service(E),
Encoder(<U as Encoder>::Error), Encoder(<U as Encoder>::Error),
Decoder(<U as Decoder>::Error), Decoder(<U as Decoder>::Error),
} }
impl<E, U: Encoder + Decoder> From<E> for FramedTransportError<E, U> { impl<E, U: Encoder + Decoder> From<E> for DispatcherError<E, U> {
fn from(err: E) -> Self { fn from(err: E) -> Self {
FramedTransportError::Service(err) DispatcherError::Service(err)
} }
} }
impl<E, U: Encoder + Decoder> fmt::Debug for FramedTransportError<E, U> impl<E, U: Encoder + Decoder> fmt::Debug for DispatcherError<E, U>
where where
E: fmt::Debug, E: fmt::Debug,
<U as Encoder>::Error: fmt::Debug, <U as Encoder>::Error: fmt::Debug,
@ -39,20 +35,14 @@ where
{ {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
FramedTransportError::Service(ref e) => { DispatcherError::Service(ref e) => write!(fmt, "DispatcherError::Service({:?})", e),
write!(fmt, "FramedTransportError::Service({:?})", e) DispatcherError::Encoder(ref e) => write!(fmt, "DispatcherError::Encoder({:?})", e),
} DispatcherError::Decoder(ref e) => write!(fmt, "DispatcherError::Decoder({:?})", e),
FramedTransportError::Encoder(ref e) => {
write!(fmt, "FramedTransportError::Encoder({:?})", e)
}
FramedTransportError::Decoder(ref e) => {
write!(fmt, "FramedTransportError::Encoder({:?})", e)
}
} }
} }
} }
impl<E, U: Encoder + Decoder> fmt::Display for FramedTransportError<E, U> impl<E, U: Encoder + Decoder> fmt::Display for DispatcherError<E, U>
where where
E: fmt::Display, E: fmt::Display,
<U as Encoder>::Error: fmt::Debug, <U as Encoder>::Error: fmt::Debug,
@ -60,25 +50,22 @@ where
{ {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
FramedTransportError::Service(ref e) => write!(fmt, "{}", e), DispatcherError::Service(ref e) => write!(fmt, "{}", e),
FramedTransportError::Encoder(ref e) => write!(fmt, "{:?}", e), DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e),
FramedTransportError::Decoder(ref e) => write!(fmt, "{:?}", e), DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e),
} }
} }
} }
pub enum FramedMessage<T> { pub enum Message<T> {
Message(T), Item(T),
Close, Close,
} }
type Rx<U> = Option<mpsc::Receiver<FramedMessage<<U as Encoder>::Item>>>;
type Inner<S: Service, U> = Cell<FramedTransportInner<<U as Encoder>::Item, S::Error>>;
/// FramedTransport - is a future that reads frames from Framed object /// FramedTransport - is a future that reads frames from Framed object
/// and pass then to the service. /// and pass then to the service.
#[pin_project::pin_project] #[pin_project::pin_project]
pub struct FramedTransport<S, T, U> pub struct Dispatcher<S, T, U>
where where
S: Service<Request = Request<U>, Response = Response<U>>, S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static, S::Error: 'static,
@ -89,26 +76,37 @@ where
<U as Encoder>::Error: std::fmt::Debug, <U as Encoder>::Error: std::fmt::Debug,
{ {
service: S, service: S,
state: TransportState<S, U>, state: State<S, U>,
framed: Framed<T, U>, framed: Framed<T, U>,
rx: Option<mpsc::Receiver<FramedMessage<<U as Encoder>::Item>>>, rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
inner: Cell<FramedTransportInner<<U as Encoder>::Item, S::Error>>, tx: mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>>,
} }
enum TransportState<S: Service, U: Encoder + Decoder> { enum State<S: Service, U: Encoder + Decoder> {
Processing, Processing,
Error(FramedTransportError<S::Error, U>), Error(DispatcherError<S::Error, U>),
FramedError(FramedTransportError<S::Error, U>), FramedError(DispatcherError<S::Error, U>),
FlushAndStop, FlushAndStop,
Stopping, Stopping,
} }
struct FramedTransportInner<I, E> { impl<S: Service, U: Encoder + Decoder> State<S, U> {
buf: VecDeque<Result<I, E>>, fn take_error(&mut self) -> DispatcherError<S::Error, U> {
task: LocalWaker, match mem::replace(self, State::Processing) {
State::Error(err) => err,
_ => panic!(),
}
}
fn take_framed_error(&mut self) -> DispatcherError<S::Error, U> {
match mem::replace(self, State::Processing) {
State::FramedError(err) => err,
_ => panic!(),
}
}
} }
impl<S, T, U> FramedTransport<S, T, U> impl<S, T, U> Dispatcher<S, T, U>
where where
S: Service<Request = Request<U>, Response = Response<U>>, S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static, S::Error: 'static,
@ -119,25 +117,19 @@ where
<U as Encoder>::Error: std::fmt::Debug, <U as Encoder>::Error: std::fmt::Debug,
{ {
pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self { pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self {
FramedTransport { let (tx, rx) = mpsc::channel();
Dispatcher {
framed, framed,
rx: None, rx,
tx,
service: service.into_service(), service: service.into_service(),
state: TransportState::Processing, state: State::Processing,
inner: Cell::new(FramedTransportInner {
buf: VecDeque::new(),
task: LocalWaker::new(),
}),
} }
} }
/// Get Sender /// Get sink
pub fn set_receiver( pub fn get_sink(&self) -> mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>> {
mut self, self.tx.clone()
rx: mpsc::Receiver<FramedMessage<<U as Encoder>::Item>>,
) -> Self {
self.rx = Some(rx);
self
} }
/// Get reference to a service wrapped by `FramedTransport` instance. /// Get reference to a service wrapped by `FramedTransport` instance.
@ -162,9 +154,99 @@ where
pub fn get_framed_mut(&mut self) -> &mut Framed<T, U> { pub fn get_framed_mut(&mut self) -> &mut Framed<T, U> {
&mut self.framed &mut self.framed
} }
fn poll_read(&mut self, cx: &mut Context<'_>) -> bool
where
S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
{
loop {
match self.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let item = match self.framed.next_item(cx) {
Poll::Ready(Some(Ok(el))) => el,
Poll::Ready(Some(Err(err))) => {
self.state = State::FramedError(DispatcherError::Decoder(err));
return true;
}
Poll::Pending => return false,
Poll::Ready(None) => {
self.state = State::Stopping;
return true;
}
};
let tx = self.tx.clone();
actix_rt::spawn(self.service.call(item).map(move |item| {
let _ = tx.send(item.map(Message::Item));
}));
}
Poll::Pending => return false,
Poll::Ready(Err(err)) => {
self.state = State::Error(DispatcherError::Service(err));
return true;
}
}
}
}
/// write to framed object
fn poll_write(&mut self, cx: &mut Context<'_>) -> bool
where
S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
{
loop {
while !self.framed.is_write_buf_full() {
match Pin::new(&mut self.rx).poll_next(cx) {
Poll::Ready(Some(Ok(Message::Item(msg)))) => {
if let Err(err) = self.framed.write(msg) {
self.state = State::FramedError(DispatcherError::Encoder(err));
return true;
}
}
Poll::Ready(Some(Ok(Message::Close))) => {
self.state = State::FlushAndStop;
return true;
}
Poll::Ready(Some(Err(err))) => {
self.state = State::Error(DispatcherError::Service(err));
return true;
}
Poll::Ready(None) | Poll::Pending => break,
}
}
if !self.framed.is_write_buf_empty() {
match self.framed.flush(cx) {
Poll::Pending => break,
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => {
debug!("Error sending data: {:?}", err);
self.state = State::FramedError(DispatcherError::Encoder(err));
return true;
}
}
} else {
break;
}
}
false
}
} }
impl<S, T, U> Future for FramedTransport<S, T, U> impl<S, T, U> Future for Dispatcher<S, T, U>
where where
S: Service<Request = Request<U>, Response = Response<U>>, S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static, S::Error: 'static,
@ -175,210 +257,50 @@ where
<U as Encoder>::Error: std::fmt::Debug, <U as Encoder>::Error: std::fmt::Debug,
<U as Decoder>::Error: std::fmt::Debug, <U as Decoder>::Error: std::fmt::Debug,
{ {
type Output = Result<(), FramedTransportError<S::Error, U>>; type Output = Result<(), DispatcherError<S::Error, U>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.get_ref().task.register(cx.waker()); loop {
let this = self.as_mut().project();
let this = self.project(); return match this.state {
poll( State::Processing => {
cx, if self.poll_read(cx) || self.poll_write(cx) {
this.service, continue;
this.state, } else {
this.framed, Poll::Pending
this.rx, }
this.inner, }
) State::Error(_) => {
} // flush write buffer
} if !self.framed.is_write_buf_empty() {
match self.framed.flush(cx) {
fn poll<S, T, U>( Poll::Pending => Poll::Pending,
cx: &mut Context<'_>, Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => {
srv: &mut S, Poll::Ready(Err(self.state.take_error()))
state: &mut TransportState<S, U>, }
framed: &mut Framed<T, U>, }
rx: &mut Rx<U>, } else {
inner: &mut Inner<S, U>, Poll::Ready(Err(self.state.take_error()))
) -> Poll<Result<(), FramedTransportError<S::Error, U>>> }
where }
S: Service<Request = Request<U>, Response = Response<U>>, State::FlushAndStop => {
S::Error: 'static, if !this.framed.is_write_buf_empty() {
S::Future: 'static, match this.framed.flush(cx) {
T: AsyncRead + AsyncWrite, Poll::Ready(Err(err)) => {
U: Decoder + Encoder, debug!("Error sending data: {:?}", err);
<U as Encoder>::Item: 'static, Poll::Ready(Ok(()))
<U as Encoder>::Error: std::fmt::Debug, }
{ Poll::Pending => Poll::Pending,
match mem::replace(state, TransportState::Processing) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
TransportState::Processing => { }
if poll_read(cx, srv, state, framed, inner) } else {
|| poll_write(cx, state, framed, rx, inner)
{
poll(cx, srv, state, framed, rx, inner)
} else {
Poll::Pending
}
}
TransportState::Error(err) => {
let is_empty = framed.is_write_buf_empty();
if is_empty || poll_write(cx, state, framed, rx, inner) {
Poll::Ready(Err(err))
} else {
*state = TransportState::Error(err);
Poll::Pending
}
}
TransportState::FlushAndStop => {
if !framed.is_write_buf_empty() {
match Pin::new(framed).poll_flush(cx) {
Poll::Ready(Err(err)) => {
debug!("Error sending data: {:?}", err);
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
} }
} else { State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())),
Poll::Ready(Ok(())) State::Stopping => Poll::Ready(Ok(())),
} };
}
TransportState::FramedError(err) => Poll::Ready(Err(err)),
TransportState::Stopping => Poll::Ready(Ok(())),
}
}
fn poll_read<S, T, U>(
cx: &mut Context<'_>,
srv: &mut S,
state: &mut TransportState<S, U>,
framed: &mut Framed<T, U>,
inner: &mut Inner<S, U>,
) -> bool
where
S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
{
loop {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let item = match framed.next_item(cx) {
Poll::Ready(Some(Ok(el))) => el,
Poll::Ready(Some(Err(err))) => {
*state =
TransportState::FramedError(FramedTransportError::Decoder(err));
return true;
}
Poll::Pending => return false,
Poll::Ready(None) => {
*state = TransportState::Stopping;
return true;
}
};
let mut cell = inner.clone();
let fut = srv.call(item).then(move |item| {
let inner = cell.get_mut();
inner.buf.push_back(item);
inner.task.wake();
ready(())
});
actix_rt::spawn(fut);
}
Poll::Pending => return false,
Poll::Ready(Err(err)) => {
*state = TransportState::Error(FramedTransportError::Service(err));
return true;
}
} }
} }
} }
/// write to framed object
fn poll_write<S, T, U>(
cx: &mut Context<'_>,
state: &mut TransportState<S, U>,
framed: &mut Framed<T, U>,
rx: &mut Rx<U>,
inner: &mut Inner<S, U>,
) -> bool
where
S: Service<Request = Request<U>, Response = Response<U>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
{
// let this = self.project();
let inner = inner.get_mut();
let mut rx_done = rx.is_none();
let mut buf_empty = inner.buf.is_empty();
loop {
while !framed.is_write_buf_full() {
if !buf_empty {
match inner.buf.pop_front().unwrap() {
Ok(msg) => {
if let Err(err) = framed.write(msg) {
*state =
TransportState::FramedError(FramedTransportError::Encoder(err));
return true;
}
buf_empty = inner.buf.is_empty();
}
Err(err) => {
*state = TransportState::Error(FramedTransportError::Service(err));
return true;
}
}
}
if !rx_done && rx.is_some() {
match Pin::new(rx.as_mut().unwrap()).poll_next(cx) {
Poll::Ready(Some(FramedMessage::Message(msg))) => {
if let Err(err) = framed.write(msg) {
*state =
TransportState::FramedError(FramedTransportError::Encoder(err));
return true;
}
}
Poll::Ready(Some(FramedMessage::Close)) => {
*state = TransportState::FlushAndStop;
return true;
}
Poll::Ready(None) => {
rx_done = true;
let _ = rx.take();
}
Poll::Pending => rx_done = true,
}
}
if rx_done && buf_empty {
break;
}
}
if !framed.is_write_buf_empty() {
match framed.flush(cx) {
Poll::Pending => break,
Poll::Ready(Err(err)) => {
debug!("Error sending data: {:?}", err);
*state = TransportState::FramedError(FramedTransportError::Encoder(err));
return true;
}
Poll::Ready(Ok(_)) => (),
}
} else {
break;
}
}
false
}

View File

@ -3,12 +3,12 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use futures::Stream; use futures::{FutureExt, Stream};
use crate::mpsc; use crate::mpsc;
#[pin_project::pin_project] #[pin_project::pin_project]
pub struct StreamDispatcher<S, T> pub struct Dispatcher<S, T>
where where
S: Stream, S: Stream,
T: Service<Request = S::Item, Response = ()> + 'static, T: Service<Request = S::Item, Response = ()> + 'static,
@ -20,7 +20,7 @@ where
err_tx: mpsc::Sender<T::Error>, err_tx: mpsc::Sender<T::Error>,
} }
impl<S, T> StreamDispatcher<S, T> impl<S, T> Dispatcher<S, T>
where where
S: Stream, S: Stream,
T: Service<Request = S::Item, Response = ()> + 'static, T: Service<Request = S::Item, Response = ()> + 'static,
@ -30,7 +30,7 @@ where
F: IntoService<T>, F: IntoService<T>,
{ {
let (err_tx, err_rx) = mpsc::channel(); let (err_tx, err_rx) = mpsc::channel();
StreamDispatcher { Dispatcher {
err_rx, err_rx,
err_tx, err_tx,
stream, stream,
@ -39,7 +39,7 @@ where
} }
} }
impl<S, T> Future for StreamDispatcher<S, T> impl<S, T> Future for Dispatcher<S, T>
where where
S: Stream, S: Stream,
T: Service<Request = S::Item, Response = ()> + 'static, T: Service<Request = S::Item, Response = ()> + 'static,
@ -54,47 +54,23 @@ where
} }
loop { loop {
match this.service.poll_ready(cx)? { return match this.service.poll_ready(cx)? {
Poll::Ready(_) => match this.stream.poll_next(cx) { Poll::Ready(_) => match this.stream.poll_next(cx) {
Poll::Ready(Some(item)) => { Poll::Ready(Some(item)) => {
actix_rt::spawn(StreamDispatcherService { let stop = this.err_tx.clone();
fut: this.service.call(item), actix_rt::spawn(this.service.call(item).map(move |res| {
stop: self.err_tx.clone(), if let Err(e) = res {
}); let _ = stop.send(e);
}
}));
this = self.as_mut().project(); this = self.as_mut().project();
continue;
} }
Poll::Pending => return Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Ready(None) => Poll::Ready(Ok(())),
}, },
Poll::Pending => return Poll::Pending, Poll::Pending => Poll::Pending,
} };
}
}
}
#[pin_project::pin_project]
struct StreamDispatcherService<F: Future, E> {
#[pin]
fut: F,
stop: mpsc::Sender<E>,
}
impl<F, E> Future for StreamDispatcherService<F, E>
where
F: Future<Output = Result<(), E>>,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.fut.poll(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let _ = this.stop.send(e);
Poll::Ready(())
}
} }
} }
} }