diff --git a/actix-ioframe/src/dispatcher.rs b/actix-ioframe/src/dispatcher.rs index 02f0da51..fa7bfe42 100644 --- a/actix-ioframe/src/dispatcher.rs +++ b/actix-ioframe/src/dispatcher.rs @@ -1,6 +1,7 @@ //! Framed dispatcher service and related utilities use std::collections::VecDeque; use std::mem; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_service::{IntoService, Service}; @@ -43,6 +44,7 @@ where framed: Framed, rx: Option::Item>>>, inner: Cell::Item, S::Error>>, + disconnect: Option>, } impl FramedDispatcher @@ -61,11 +63,13 @@ where service: F, rx: mpsc::UnboundedReceiver::Item>>, sink: Sink<::Item>, + disconnect: Option>, ) -> Self { FramedDispatcher { framed, state, sink, + disconnect, rx: Some(rx), service: service.into_service(), dispatch_state: FramedState::Processing, @@ -124,6 +128,12 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { + fn disconnect(&mut self, error: bool) { + if let Some(ref disconnect) = self.disconnect { + (&*disconnect)(&mut *self.state.get_mut(), error); + } + } + fn poll_read(&mut self) -> bool { loop { match self.service.poll_ready() { @@ -274,6 +284,7 @@ where if self.framed.is_write_buf_empty() || (self.poll_write() || self.framed.is_write_buf_empty()) { + self.disconnect(true); Err(err) } else { self.dispatch_state = FramedState::Error(err); @@ -296,9 +307,13 @@ where for tx in vec.drain(..) { let _ = tx.send(()); } + self.disconnect(false); Ok(Async::Ready(())) } - FramedState::FramedError(err) => Err(err), + FramedState::FramedError(err) => { + self.disconnect(true); + Err(err) + } FramedState::Stopping => Ok(Async::Ready(())), } } diff --git a/actix-ioframe/src/service.rs b/actix-ioframe/src/service.rs index b31b88ef..0e32c4c3 100644 --- a/actix-ioframe/src/service.rs +++ b/actix-ioframe/src/service.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; use actix_service::{IntoNewService, IntoService, NewService, Service}; -use futures::{Async, Future, Poll}; +use futures::{Async, Future, IntoFuture, Poll}; use crate::connect::{Connect, ConnectResult}; use crate::dispatcher::FramedDispatcher; @@ -33,6 +33,7 @@ impl Builder { { ServiceBuilder { connect: connect.into_service(), + disconnect: None, _t: PhantomData, } } @@ -53,6 +54,7 @@ impl Builder { { NewServiceBuilder { connect: connect.into_new_service(), + disconnect: None, _t: PhantomData, } } @@ -60,6 +62,7 @@ impl Builder { pub struct ServiceBuilder { connect: C, + disconnect: Option>, _t: PhantomData<(St, Io, Codec)>, } @@ -73,6 +76,23 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { + /// Callback to execute on disconnect + /// + /// Second parameter indicates error occured during disconnect. + pub fn disconnect(mut self, disconnect: F) -> Self + where + F: Fn(&mut St, bool) -> Out + 'static, + Out: IntoFuture, + Out::Future: 'static, + { + self.disconnect = Some(Rc::new(move |st, error| { + let fut = disconnect(st, error).into_future(); + tokio_current_thread::spawn(fut.map_err(|_| ()).map(|_| ())); + })); + self + } + + /// Provide stream items handler service and construct service factory. pub fn finish( self, service: F, @@ -90,6 +110,7 @@ where FramedServiceImpl { connect: self.connect, handler: Rc::new(service.into_new_service()), + disconnect: self.disconnect.clone(), _t: PhantomData, } } @@ -97,6 +118,7 @@ where pub struct NewServiceBuilder { connect: C, + disconnect: Option>, _t: PhantomData<(St, Io, Codec)>, } @@ -111,6 +133,22 @@ where ::Item: 'static, ::Error: std::fmt::Debug, { + /// Callback to execute on disconnect + /// + /// Second parameter indicates error occured during disconnect. + pub fn disconnect(mut self, disconnect: F) -> Self + where + F: Fn(&mut St, bool) -> Out + 'static, + Out: IntoFuture, + Out::Future: 'static, + { + self.disconnect = Some(Rc::new(move |st, error| { + let fut = disconnect(st, error).into_future(); + tokio_current_thread::spawn(fut.map_err(|_| ()).map(|_| ())); + })); + self + } + pub fn finish( self, service: F, @@ -133,6 +171,7 @@ where FramedService { connect: self.connect, handler: Rc::new(service.into_new_service()), + disconnect: self.disconnect, _t: PhantomData, } } @@ -141,6 +180,7 @@ where pub(crate) struct FramedService { connect: C, handler: Rc, + disconnect: Option>, _t: PhantomData<(St, Io, Codec)>, } @@ -172,6 +212,7 @@ where fn new_service(&self, _: &()) -> Self::Future { let handler = self.handler.clone(); + let disconnect = self.disconnect.clone(); // create connect service and then create service impl Box::new( @@ -180,6 +221,7 @@ where .map(move |connect| FramedServiceImpl { connect, handler, + disconnect, _t: PhantomData, }), ) @@ -189,6 +231,7 @@ where pub struct FramedServiceImpl { connect: C, handler: Rc, + disconnect: Option>, _t: PhantomData<(St, Io, Codec)>, } @@ -224,6 +267,7 @@ where self.connect.call(Connect::new(req)), self.handler.clone(), ), + disconnect: self.disconnect.clone(), } } } @@ -246,6 +290,7 @@ where ::Error: std::fmt::Debug, { inner: FramedServiceImplResponseInner, + disconnect: Option>, } enum FramedServiceImplResponseInner @@ -315,6 +360,7 @@ where handler, res.rx, res.sink, + self.disconnect.clone(), )); self.poll() }