diff --git a/actix-ioframe/src/cell.rs b/actix-ioframe/src/cell.rs deleted file mode 100644 index c14ffefc..00000000 --- a/actix-ioframe/src/cell.rs +++ /dev/null @@ -1,39 +0,0 @@ -//! Custom cell impl - -use std::cell::UnsafeCell; -use std::fmt; -use std::rc::Rc; - -pub(crate) struct Cell { - inner: Rc>, -} - -impl Clone for Cell { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } -} - -impl fmt::Debug for Cell { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.inner.fmt(f) - } -} - -impl Cell { - pub fn new(inner: T) -> Self { - Self { - inner: Rc::new(UnsafeCell::new(inner)), - } - } - - pub(crate) unsafe fn get_ref(&mut self) -> &T { - &*self.inner.as_ref().get() - } - - pub(crate) unsafe fn get_mut(&mut self) -> &mut T { - &mut *self.inner.as_ref().get() - } -} diff --git a/actix-ioframe/src/connect.rs b/actix-ioframe/src/connect.rs index 978c7e57..f7534413 100644 --- a/actix-ioframe/src/connect.rs +++ b/actix-ioframe/src/connect.rs @@ -3,55 +3,51 @@ use std::pin::Pin; use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; -use actix_utils::mpsc; use futures::Stream; -use crate::dispatcher::FramedMessage; use crate::sink::Sink; -pub struct Connect { +pub struct Connect +where + Codec: Encoder + Decoder, +{ io: Io, + sink: Sink<::Item, Err>, _t: PhantomData<(St, Codec)>, } -impl Connect +impl Connect where Io: AsyncRead + AsyncWrite, + Codec: Encoder + Decoder, { - pub(crate) fn new(io: Io) -> Self { + pub(crate) fn new(io: Io, sink: Sink<::Item, Err>) -> Self { Self { io, + sink, _t: PhantomData, } } - pub fn codec(self, codec: Codec) -> ConnectResult - where - Codec: Encoder + Decoder, - { - let (tx, rx) = mpsc::channel(); - let sink = Sink::new(tx); - + pub fn codec(self, codec: Codec) -> ConnectResult { ConnectResult { state: (), + sink: self.sink, framed: Framed::new(self.io, codec), - rx, - sink, } } } #[pin_project::pin_project] -pub struct ConnectResult { +pub struct ConnectResult { pub(crate) state: St, pub(crate) framed: Framed, - pub(crate) rx: mpsc::Receiver::Item>>, - pub(crate) sink: Sink<::Item>, + pub(crate) sink: Sink<::Item, Err>, } -impl ConnectResult { +impl ConnectResult { #[inline] - pub fn sink(&self) -> &Sink<::Item> { + pub fn sink(&self) -> &Sink<::Item, Err> { &self.sink } @@ -66,17 +62,16 @@ impl ConnectResult { } #[inline] - pub fn state(self, state: S) -> ConnectResult { + pub fn state(self, state: S) -> ConnectResult { ConnectResult { state, framed: self.framed, - rx: self.rx, sink: self.sink, } } } -impl Stream for ConnectResult +impl Stream for ConnectResult where Io: AsyncRead + AsyncWrite, Codec: Encoder + Decoder, @@ -88,7 +83,8 @@ where } } -impl futures::Sink<::Item> for ConnectResult +impl futures::Sink<::Item> + for ConnectResult where Io: AsyncRead + AsyncWrite, Codec: Encoder + Decoder, diff --git a/actix-ioframe/src/dispatcher.rs b/actix-ioframe/src/dispatcher.rs index f01a7721..8db9a436 100644 --- a/actix-ioframe/src/dispatcher.rs +++ b/actix-ioframe/src/dispatcher.rs @@ -1,39 +1,34 @@ //! Framed dispatcher service and related utilities -use std::collections::VecDeque; -use std::mem; use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_service::{IntoService, Service}; -use actix_utils::task::LocalWaker; use actix_utils::{mpsc, oneshot}; -use futures::future::ready; -use futures::{FutureExt, Sink as FutureSink, Stream}; +use futures::{FutureExt, Stream}; use log::debug; -use crate::cell::Cell; use crate::error::ServiceError; use crate::item::Item; use crate::sink::Sink; -type Request = Item; +type Request = Item; type Response = ::Item; -pub(crate) enum FramedMessage { - Message(T), - Close, +pub(crate) enum Message { + Item(T), WaitClose(oneshot::Sender<()>), + Close, } /// FramedTransport - is a future that reads frames from Framed object /// and pass then to the service. #[pin_project::pin_project] -pub(crate) struct FramedDispatcher +pub(crate) struct Dispatcher where St: Clone, - S: Service, Response = Option>>, + S: Service, Response = Option>, Error = E>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, @@ -42,19 +37,19 @@ where ::Error: std::fmt::Debug, { service: S, - sink: Sink<::Item>, + sink: Sink<::Item, E>, state: St, dispatch_state: FramedState, framed: Framed, - rx: Option::Item>>>, - inner: Cell::Item, S::Error>>, + rx: mpsc::Receiver::Item>, E>>, + tx: mpsc::Sender::Item>, E>>, disconnect: Option>, } -impl FramedDispatcher +impl Dispatcher where St: Clone, - S: Service, Response = Option>>, + S: Service, Response = Option>, Error = E>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, @@ -66,22 +61,21 @@ where framed: Framed, state: St, service: F, - rx: mpsc::Receiver::Item>>, - sink: Sink<::Item>, + sink: Sink<::Item, E>, + rx: mpsc::Receiver::Item>, E>>, disconnect: Option>, ) -> Self { - FramedDispatcher { + let tx = rx.sender(); + + Dispatcher { framed, state, sink, disconnect, - rx: Some(rx), + rx, + tx, service: service.into_service(), dispatch_state: FramedState::Processing, - inner: Cell::new(FramedDispatcherInner { - buf: VecDeque::new(), - task: LocalWaker::new(), - }), } } } @@ -116,17 +110,26 @@ impl FramedState { } } } + + fn take_error(&mut self) -> ServiceError { + match std::mem::replace(self, FramedState::Processing) { + FramedState::Error(err) => err, + _ => panic!(), + } + } + + fn take_framed_error(&mut self) -> ServiceError { + match std::mem::replace(self, FramedState::Processing) { + FramedState::FramedError(err) => err, + _ => panic!(), + } + } } -struct FramedDispatcherInner { - buf: VecDeque>, - task: LocalWaker, -} - -impl FramedDispatcher +impl Dispatcher where St: Clone, - S: Service, Response = Option>>, + S: Service, Response = Option>, Error = E>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, @@ -134,263 +137,150 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { + fn poll_read(&mut self, cx: &mut Context<'_>) -> bool { + loop { + match self.service.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + let item = match self.framed.next_item(cx) { + Poll::Ready(Some(Ok(el))) => el, + Poll::Ready(Some(Err(err))) => { + self.dispatch_state = + FramedState::FramedError(ServiceError::Decoder(err)); + return true; + } + Poll::Pending => return false, + Poll::Ready(None) => { + log::trace!("Client disconnected"); + self.dispatch_state = FramedState::Stopping; + return true; + } + }; + + let tx = self.tx.clone(); + actix_rt::spawn( + self.service + .call(Item::new(self.state.clone(), self.sink.clone(), item)) + .map(move |item| { + let item = match item { + Ok(Some(item)) => Ok(Message::Item(item)), + Ok(None) => return, + Err(err) => Err(err), + }; + let _ = tx.send(item); + }), + ); + } + Poll::Pending => return false, + Poll::Ready(Err(err)) => { + self.dispatch_state = FramedState::Error(ServiceError::Service(err)); + return true; + } + } + } + } + + /// write to framed object + fn poll_write(&mut self, cx: &mut Context<'_>) -> bool { + loop { + while !self.framed.is_write_buf_full() { + match Pin::new(&mut self.rx).poll_next(cx) { + Poll::Ready(Some(Ok(Message::Item(msg)))) => { + if let Err(err) = self.framed.write(msg) { + self.dispatch_state = + FramedState::FramedError(ServiceError::Encoder(err)); + return true; + } + } + Poll::Ready(Some(Ok(Message::Close))) => { + self.dispatch_state.stop(None); + return true; + } + Poll::Ready(Some(Ok(Message::WaitClose(tx)))) => { + self.dispatch_state.stop(Some(tx)); + return true; + } + Poll::Ready(Some(Err(err))) => { + self.dispatch_state = FramedState::Error(ServiceError::Service(err)); + return true; + } + Poll::Ready(None) | Poll::Pending => break, + } + } + + if !self.framed.is_write_buf_empty() { + match self.framed.flush(cx) { + Poll::Pending => break, + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => { + debug!("Error sending data: {:?}", err); + self.dispatch_state = + FramedState::FramedError(ServiceError::Encoder(err)); + return true; + } + } + } else { + break; + } + } + false + } + pub(crate) fn poll( &mut self, cx: &mut Context<'_>, ) -> Poll>> { - let this = self; - unsafe { this.inner.get_ref().task.register(cx.waker()) }; - - poll( - cx, - &mut this.service, - &mut this.state, - &mut this.sink, - &mut this.framed, - &mut this.dispatch_state, - &mut this.rx, - &mut this.inner, - &mut this.disconnect, - ) - } -} - -fn poll( - cx: &mut Context<'_>, - srv: &mut S, - state: &mut St, - sink: &mut Sink<::Item>, - framed: &mut Framed, - dispatch_state: &mut FramedState, - rx: &mut Option::Item>>>, - inner: &mut Cell::Item, S::Error>>, - disconnect: &mut Option>, -) -> Poll>> -where - St: Clone, - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - match mem::replace(dispatch_state, FramedState::Processing) { - FramedState::Processing => { - if poll_read(cx, srv, state, sink, framed, dispatch_state, inner) - || poll_write(cx, framed, dispatch_state, rx, inner) - { - poll( - cx, - srv, - state, - sink, - framed, - dispatch_state, - rx, - inner, - disconnect, - ) - } else { - Poll::Pending - } - } - FramedState::Error(err) => { - if framed.is_write_buf_empty() - || (poll_write(cx, framed, dispatch_state, rx, inner) - || framed.is_write_buf_empty()) - { - if let Some(ref disconnect) = disconnect { - (&*disconnect)(&mut *state, true); + match self.dispatch_state { + FramedState::Processing => { + if self.poll_read(cx) || self.poll_write(cx) { + self.poll(cx) + } else { + Poll::Pending } - Poll::Ready(Err(err)) - } else { - *dispatch_state = FramedState::Error(err); - Poll::Pending } - } - FramedState::FlushAndStop(mut vec) => { - if !framed.is_write_buf_empty() { - match Pin::new(framed).poll_flush(cx) { - Poll::Ready(Err(err)) => { - debug!("Error sending data: {:?}", err); - } - Poll::Pending => { - *dispatch_state = FramedState::FlushAndStop(vec); + FramedState::Error(_) => { + // flush write buffer + if !self.framed.is_write_buf_empty() { + if let Poll::Pending = self.framed.flush(cx) { return Poll::Pending; } - Poll::Ready(_) => (), } - }; - for tx in vec.drain(..) { - let _ = tx.send(()); + if let Some(ref disconnect) = self.disconnect { + (&*disconnect)(&mut self.state, true); + } + Poll::Ready(Err(self.dispatch_state.take_error())) } - if let Some(ref disconnect) = disconnect { - (&*disconnect)(&mut *state, false); - } - Poll::Ready(Ok(())) - } - FramedState::FramedError(err) => { - if let Some(ref disconnect) = disconnect { - (&*disconnect)(&mut *state, true); - } - Poll::Ready(Err(err)) - } - FramedState::Stopping => { - if let Some(ref disconnect) = disconnect { - (&*disconnect)(&mut *state, false); - } - Poll::Ready(Ok(())) - } - } -} - -fn poll_read( - cx: &mut Context<'_>, - srv: &mut S, - state: &mut St, - sink: &mut Sink<::Item>, - framed: &mut Framed, - dispatch_state: &mut FramedState, - inner: &mut Cell::Item, S::Error>>, -) -> bool -where - St: Clone, - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - loop { - match srv.poll_ready(cx) { - Poll::Ready(Ok(_)) => { - let item = match framed.next_item(cx) { - Poll::Ready(Some(Ok(el))) => el, - Poll::Ready(Some(Err(err))) => { - *dispatch_state = FramedState::FramedError(ServiceError::Decoder(err)); - return true; - } - Poll::Pending => return false, - Poll::Ready(None) => { - log::trace!("Client disconnected"); - *dispatch_state = FramedState::Stopping; - return true; + FramedState::FlushAndStop(ref mut vec) => { + if !self.framed.is_write_buf_empty() { + match self.framed.flush(cx) { + Poll::Ready(Err(err)) => { + debug!("Error sending data: {:?}", err); + } + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(_) => (), } }; - - let mut cell = inner.clone(); - actix_rt::spawn(srv.call(Item::new(state.clone(), sink.clone(), item)).then( - move |item| { - let item = match item { - Ok(Some(item)) => Ok(item), - Ok(None) => return ready(()), - Err(err) => Err(err), - }; - unsafe { - let inner = cell.get_mut(); - inner.buf.push_back(item); - inner.task.wake(); - } - ready(()) - }, - )); + for tx in vec.drain(..) { + let _ = tx.send(()); + } + if let Some(ref disconnect) = self.disconnect { + (&*disconnect)(&mut self.state, false); + } + Poll::Ready(Ok(())) } - Poll::Pending => return false, - Poll::Ready(Err(err)) => { - *dispatch_state = FramedState::Error(ServiceError::Service(err)); - return true; + FramedState::FramedError(_) => { + if let Some(ref disconnect) = self.disconnect { + (&*disconnect)(&mut self.state, true); + } + Poll::Ready(Err(self.dispatch_state.take_framed_error())) + } + FramedState::Stopping => { + if let Some(ref disconnect) = self.disconnect { + (&*disconnect)(&mut self.state, false); + } + Poll::Ready(Ok(())) } } } } - -/// write to framed object -fn poll_write( - cx: &mut Context<'_>, - framed: &mut Framed, - dispatch_state: &mut FramedState, - rx: &mut Option::Item>>>, - inner: &mut Cell::Item, S::Error>>, -) -> bool -where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, - ::Item: 'static, - ::Error: std::fmt::Debug, -{ - let inner = unsafe { inner.get_mut() }; - let mut rx_done = rx.is_none(); - let mut buf_empty = inner.buf.is_empty(); - loop { - while !framed.is_write_buf_full() { - if !buf_empty { - match inner.buf.pop_front().unwrap() { - Ok(msg) => { - if let Err(err) = framed.write(msg) { - *dispatch_state = - FramedState::FramedError(ServiceError::Encoder(err)); - return true; - } - buf_empty = inner.buf.is_empty(); - } - Err(err) => { - *dispatch_state = FramedState::Error(ServiceError::Service(err)); - return true; - } - } - } - - if !rx_done && rx.is_some() { - match Pin::new(rx.as_mut().unwrap()).poll_next(cx) { - Poll::Ready(Some(FramedMessage::Message(msg))) => { - if let Err(err) = framed.write(msg) { - *dispatch_state = - FramedState::FramedError(ServiceError::Encoder(err)); - return true; - } - } - Poll::Ready(Some(FramedMessage::Close)) => { - dispatch_state.stop(None); - return true; - } - Poll::Ready(Some(FramedMessage::WaitClose(tx))) => { - dispatch_state.stop(Some(tx)); - return true; - } - Poll::Ready(None) => { - rx_done = true; - let _ = rx.take(); - } - Poll::Pending => rx_done = true, - } - } - - if rx_done && buf_empty { - break; - } - } - - if !framed.is_write_buf_empty() { - match framed.flush(cx) { - Poll::Pending => break, - Poll::Ready(Err(err)) => { - debug!("Error sending data: {:?}", err); - *dispatch_state = FramedState::FramedError(ServiceError::Encoder(err)); - return true; - } - Poll::Ready(_) => (), - } - } else { - break; - } - } - - false -} diff --git a/actix-ioframe/src/item.rs b/actix-ioframe/src/item.rs index b0f5c4d8..03db8131 100644 --- a/actix-ioframe/src/item.rs +++ b/actix-ioframe/src/item.rs @@ -5,19 +5,19 @@ use actix_codec::{Decoder, Encoder}; use crate::sink::Sink; -pub struct Item { +pub struct Item { state: St, - sink: Sink<::Item>, + sink: Sink<::Item, E>, item: ::Item, } -impl Item +impl Item where Codec: Encoder + Decoder, { pub(crate) fn new( state: St, - sink: Sink<::Item>, + sink: Sink<::Item, E>, item: ::Item, ) -> Self { Item { state, sink, item } @@ -34,7 +34,7 @@ where } #[inline] - pub fn sink(&self) -> &Sink<::Item> { + pub fn sink(&self) -> &Sink<::Item, E> { &self.sink } @@ -44,12 +44,18 @@ where } #[inline] - pub fn into_parts(self) -> (St, Sink<::Item>, ::Item) { + pub fn into_parts( + self, + ) -> ( + St, + Sink<::Item, E>, + ::Item, + ) { (self.state, self.sink, self.item) } } -impl Deref for Item +impl Deref for Item where Codec: Encoder + Decoder, { @@ -61,7 +67,7 @@ where } } -impl DerefMut for Item +impl DerefMut for Item where Codec: Encoder + Decoder, { @@ -71,12 +77,12 @@ where } } -impl fmt::Debug for Item +impl fmt::Debug for Item where Codec: Encoder + Decoder, ::Item: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("FramedItem").field(&self.item).finish() + f.debug_tuple("Item").field(&self.item).finish() } } diff --git a/actix-ioframe/src/lib.rs b/actix-ioframe/src/lib.rs index b33f262c..cb4f4ae2 100644 --- a/actix-ioframe/src/lib.rs +++ b/actix-ioframe/src/lib.rs @@ -1,7 +1,6 @@ #![deny(rust_2018_idioms, warnings)] #![allow(clippy::type_complexity, clippy::too_many_arguments)] -mod cell; mod connect; mod dispatcher; mod error; diff --git a/actix-ioframe/src/service.rs b/actix-ioframe/src/service.rs index 13913edc..82576f04 100644 --- a/actix-ioframe/src/service.rs +++ b/actix-ioframe/src/service.rs @@ -6,17 +6,20 @@ use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; use actix_service::{IntoService, IntoServiceFactory, Service, ServiceFactory}; +use actix_utils::mpsc; use either::Either; use futures::future::{FutureExt, LocalBoxFuture}; use pin_project::project; use crate::connect::{Connect, ConnectResult}; -use crate::dispatcher::FramedDispatcher; +use crate::dispatcher::{Dispatcher, Message}; use crate::error::ServiceError; use crate::item::Item; +use crate::sink::Sink; -type RequestItem = Item; +type RequestItem = Item; type ResponseItem = Option<::Item>; +type ServiceResult = Result::Item>, E>; /// Service builder - structure that follows the builder pattern /// for building instances for framed services. @@ -34,11 +37,15 @@ impl Builder { } /// Construct framed handler service with specified connect service - pub fn service(self, connect: F) -> ServiceBuilder + pub fn service(self, connect: F) -> ServiceBuilder where F: IntoService, Io: AsyncRead + AsyncWrite, - C: Service, Response = ConnectResult>, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = E, + >, Codec: Decoder + Encoder, { ServiceBuilder { @@ -49,16 +56,17 @@ impl Builder { } /// Construct framed handler new service with specified connect service - pub fn factory(self, connect: F) -> NewServiceBuilder + pub fn factory(self, connect: F) -> NewServiceBuilder where F: IntoServiceFactory, + E: 'static, Io: AsyncRead + AsyncWrite, C: ServiceFactory< Config = (), - Request = Connect, - Response = ConnectResult, + Request = Connect, + Response = ConnectResult, + Error = E, >, - C::Error: 'static, C::Future: 'static, Codec: Decoder + Encoder, { @@ -70,17 +78,20 @@ impl Builder { } } -pub struct ServiceBuilder { +pub struct ServiceBuilder { connect: C, disconnect: Option>, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Io, Codec, Err)>, } -impl ServiceBuilder +impl ServiceBuilder where St: Clone, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, Io: AsyncRead + AsyncWrite, Codec: Decoder + Encoder, ::Item: 'static, @@ -98,16 +109,16 @@ where } /// Provide stream items handler service and construct service factory. - pub fn finish(self, service: F) -> FramedServiceImpl + pub fn finish(self, service: F) -> FramedServiceImpl where F: IntoServiceFactory, T: ServiceFactory< - Config = St, - Request = RequestItem, - Response = ResponseItem, - Error = C::Error, - InitError = C::Error, - > + 'static, + Config = St, + Request = RequestItem, + Response = ResponseItem, + Error = Err, + InitError = Err, + >, { FramedServiceImpl { connect: self.connect, @@ -118,22 +129,23 @@ where } } -pub struct NewServiceBuilder { +pub struct NewServiceBuilder { connect: C, disconnect: Option>, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Io, Codec, Err)>, } -impl NewServiceBuilder +impl NewServiceBuilder where St: Clone, Io: AsyncRead + AsyncWrite, + Err: 'static, C: ServiceFactory< Config = (), - Request = Connect, - Response = ConnectResult, + Request = Connect, + Response = ConnectResult, + Error = Err, >, - C::Error: 'static, C::Future: 'static, Codec: Decoder + Encoder, ::Item: 'static, @@ -150,15 +162,15 @@ where self } - pub fn finish(self, service: F) -> FramedService + pub fn finish(self, service: F) -> FramedService where F: IntoServiceFactory, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, > + 'static, { FramedService { @@ -170,32 +182,34 @@ where } } -pub struct FramedService { +pub struct FramedService { connect: C, handler: Rc, disconnect: Option>, - _t: PhantomData<(St, Io, Codec, Cfg)>, + _t: PhantomData<(St, Io, Codec, Err, Cfg)>, } -impl ServiceFactory for FramedService +impl ServiceFactory + for FramedService where St: Clone + 'static, Io: AsyncRead + AsyncWrite, C: ServiceFactory< Config = (), - Request = Connect, - Response = ConnectResult, + Request = Connect, + Response = ConnectResult, + Error = Err, >, - C::Error: 'static, C::Future: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, > + 'static, ::Future: 'static, + Err: 'static, Codec: Decoder + Encoder, ::Item: 'static, ::Error: std::fmt::Debug, @@ -205,7 +219,7 @@ where type Response = (); type Error = ServiceError; type InitError = C::InitError; - type Service = FramedServiceImpl; + type Service = FramedServiceImpl; type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: Cfg) -> Self::Future { @@ -227,25 +241,29 @@ where } } -pub struct FramedServiceImpl { +pub struct FramedServiceImpl { connect: C, handler: Rc, disconnect: Option>, - _t: PhantomData<(St, Io, Codec)>, + _t: PhantomData<(St, Io, Codec, Err)>, } -impl Service for FramedServiceImpl +impl Service for FramedServiceImpl where St: Clone, Io: AsyncRead + AsyncWrite, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, + Err: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, >, ::Future: 'static, Codec: Decoder + Encoder, @@ -254,36 +272,43 @@ where { type Request = Io; type Response = (); - type Error = ServiceError; - type Future = FramedServiceImplResponse; + type Error = ServiceError; + type Future = FramedServiceImplResponse; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.connect.poll_ready(cx).map_err(|e| e.into()) } fn call(&mut self, req: Io) -> Self::Future { + let (tx, rx) = mpsc::channel(); + let sink = Sink::new(tx); FramedServiceImplResponse { inner: FramedServiceImplResponseInner::Connect( - self.connect.call(Connect::new(req)), + self.connect.call(Connect::new(req, sink.clone())), self.handler.clone(), self.disconnect.clone(), + Some(rx), ), } } } #[pin_project::pin_project] -pub struct FramedServiceImplResponse +pub struct FramedServiceImplResponse where St: Clone, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, + Err: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, >, ::Future: 'static, Io: AsyncRead + AsyncWrite, @@ -292,20 +317,24 @@ where ::Error: std::fmt::Debug, { #[pin] - inner: FramedServiceImplResponseInner, + inner: FramedServiceImplResponseInner, } -impl Future for FramedServiceImplResponse +impl Future for FramedServiceImplResponse where St: Clone, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, + Err: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, >, ::Future: 'static, Io: AsyncRead + AsyncWrite, @@ -313,7 +342,7 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { - type Output = Result<(), ServiceError>; + type Output = Result<(), ServiceError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.as_mut().project(); @@ -331,17 +360,21 @@ where } #[pin_project::pin_project] -enum FramedServiceImplResponseInner +enum FramedServiceImplResponseInner where St: Clone, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, + Err: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, >, ::Future: 'static, Io: AsyncRead + AsyncWrite, @@ -349,26 +382,36 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { - Connect(#[pin] C::Future, Rc, Option>), + Connect( + #[pin] C::Future, + Rc, + Option>, + Option>>, + ), Handler( #[pin] T::Future, - Option>, + Option>, Option>, + Option>>, ), - Dispatcher(#[pin] FramedDispatcher), + Dispatcher(Dispatcher), } -impl FramedServiceImplResponseInner +impl FramedServiceImplResponseInner where St: Clone, - C: Service, Response = ConnectResult>, - C::Error: 'static, + C: Service< + Request = Connect, + Response = ConnectResult, + Error = Err, + >, + Err: 'static, T: ServiceFactory< Config = St, - Request = RequestItem, + Request = RequestItem, Response = ResponseItem, - Error = C::Error, - InitError = C::Error, + Error = Err, + InitError = Err, >, ::Future: 'static, Io: AsyncRead + AsyncWrite, @@ -381,42 +424,44 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Either< - FramedServiceImplResponseInner, - Poll>>, + FramedServiceImplResponseInner, + Poll>>, > { #[project] match self.project() { - FramedServiceImplResponseInner::Connect(fut, handler, disconnect) => { + FramedServiceImplResponseInner::Connect(fut, handler, disconnect, rx) => { match fut.poll(cx) { Poll::Ready(Ok(res)) => { Either::Left(FramedServiceImplResponseInner::Handler( handler.new_service(res.state.clone()), Some(res), disconnect.take(), + rx.take(), )) } Poll::Pending => Either::Right(Poll::Pending), Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))), } } - FramedServiceImplResponseInner::Handler(fut, res, disconnect) => match fut.poll(cx) - { - Poll::Ready(Ok(handler)) => { - let res = res.take().unwrap(); - Either::Left(FramedServiceImplResponseInner::Dispatcher( - FramedDispatcher::new( - res.framed, - res.state, - handler, - res.rx, - res.sink, - disconnect.take(), - ), - )) + FramedServiceImplResponseInner::Handler(fut, res, disconnect, rx) => { + match fut.poll(cx) { + Poll::Ready(Ok(handler)) => { + let res = res.take().unwrap(); + Either::Left(FramedServiceImplResponseInner::Dispatcher( + Dispatcher::new( + res.framed, + res.state, + handler, + res.sink, + rx.take().unwrap(), + disconnect.take(), + ), + )) + } + Poll::Pending => Either::Right(Poll::Pending), + Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))), } - Poll::Pending => Either::Right(Poll::Pending), - Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))), - }, + } FramedServiceImplResponseInner::Dispatcher(ref mut fut) => { Either::Right(fut.poll(cx)) } diff --git a/actix-ioframe/src/sink.rs b/actix-ioframe/src/sink.rs index 43c0d574..4e0fe025 100644 --- a/actix-ioframe/src/sink.rs +++ b/actix-ioframe/src/sink.rs @@ -3,41 +3,41 @@ use std::fmt; use actix_utils::{mpsc, oneshot}; use futures::future::{Future, FutureExt}; -use crate::dispatcher::FramedMessage; +use crate::dispatcher::Message; -pub struct Sink(mpsc::Sender>); +pub struct Sink(mpsc::Sender, E>>); -impl Clone for Sink { +impl Clone for Sink { fn clone(&self) -> Self { Sink(self.0.clone()) } } -impl Sink { - pub(crate) fn new(tx: mpsc::Sender>) -> Self { +impl Sink { + pub(crate) fn new(tx: mpsc::Sender, E>>) -> Self { Sink(tx) } /// Close connection pub fn close(&self) { - let _ = self.0.send(FramedMessage::Close); + let _ = self.0.send(Ok(Message::Close)); } /// Close connection pub fn wait_close(&self) -> impl Future { let (tx, rx) = oneshot::channel(); - let _ = self.0.send(FramedMessage::WaitClose(tx)); + let _ = self.0.send(Ok(Message::WaitClose(tx))); rx.map(|_| ()) } /// Send item pub fn send(&self, item: T) { - let _ = self.0.send(FramedMessage::Message(item)); + let _ = self.0.send(Ok(Message::Item(item))); } } -impl fmt::Debug for Sink { +impl fmt::Debug for Sink { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Sink").finish() } diff --git a/actix-ioframe/tests/test_server.rs b/actix-ioframe/tests/test_server.rs index 81da99c0..35bdb983 100644 --- a/actix-ioframe/tests/test_server.rs +++ b/actix-ioframe/tests/test_server.rs @@ -22,7 +22,7 @@ async fn test_disconnect() -> std::io::Result<()> { let disconnect1 = disconnect1.clone(); Builder::new() - .factory(fn_service(|conn: Connect<_>| { + .factory(fn_service(|conn: Connect<_, _, _>| { ok(conn.codec(BytesCodec).state(State)) })) .disconnect(move |_, _| { @@ -32,7 +32,7 @@ async fn test_disconnect() -> std::io::Result<()> { }); let mut client = Builder::new() - .service(|conn: Connect<_>| { + .service(|conn: Connect<_, _, _>| { let conn = conn.codec(BytesCodec).state(State); conn.sink().close(); ok(conn)