diff --git a/actix-ioframe/src/cell.rs b/actix-ioframe/src/cell.rs index 419709ce..1794dd66 100644 --- a/actix-ioframe/src/cell.rs +++ b/actix-ioframe/src/cell.rs @@ -29,10 +29,6 @@ impl Cell { } } - pub fn get_ref(&self) -> &T { - unsafe { &*self.inner.as_ref().get() } - } - pub fn get_mut(&mut self) -> &mut T { unsafe { &mut *self.inner.as_ref().get() } } diff --git a/actix-ioframe/src/connect.rs b/actix-ioframe/src/connect.rs index 406066f2..29190228 100644 --- a/actix-ioframe/src/connect.rs +++ b/actix-ioframe/src/connect.rs @@ -1,140 +1,45 @@ +use std::marker::PhantomData; + use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use futures::unsync::mpsc; -use crate::cell::Cell; use crate::dispatcher::FramedMessage; use crate::sink::Sink; pub struct Connect { - io: IoItem, - state: St, + io: Io, + _t: PhantomData<(St, Codec)>, } -enum IoItem { - Io(Io), - Framed(Framed), -} - -impl IoItem +impl Connect where Io: AsyncRead + AsyncWrite, { - fn into_codec(self, codec: Codec) -> IoItem - where - Codec: Encoder + Decoder, - { - match self { - IoItem::Io(io) => IoItem::Framed(Framed::new(io, codec)), - IoItem::Framed(framed) => IoItem::Framed(framed.into_framed(codec)), - } - } - - fn as_framed(&mut self) -> &mut Framed - where - C: Encoder + Decoder, - { - match self { - IoItem::Io(_) => panic!("Codec is not set"), - IoItem::Framed(ref mut framed) => framed, - } - } - - fn into_framed(self) -> Framed - where - C: Encoder + Decoder, - { - match self { - IoItem::Io(_) => panic!("Codec is not set"), - IoItem::Framed(framed) => framed, - } - } -} - -impl Connect { pub(crate) fn new(io: Io) -> Self { Self { - io: IoItem::Io(io), - state: (), + io, + _t: PhantomData, } } -} -impl Connect -where - Io: AsyncRead + AsyncWrite, -{ - pub fn codec(self, codec: Codec) -> Connect + pub fn codec(self, codec: Codec) -> ConnectResult where Codec: Encoder + Decoder, { - Connect { - io: self.io.into_codec(codec), - state: self.state, - } - } - - pub fn state(self, state: St) -> Connect { - Connect { state, io: self.io } - } -} - -impl Connect -where - C: Encoder + Decoder, - Io: AsyncRead + AsyncWrite, -{ - pub fn into_result(self) -> ConnectResult { let (tx, rx) = mpsc::unbounded(); let sink = Sink::new(tx); ConnectResult { - state: Cell::new(self.state), - framed: self.io.into_framed(), + state: (), + framed: Framed::new(self.io, codec), rx, sink, } } } -impl futures::Stream for Connect -where - Codec: Encoder + Decoder, - Io: AsyncRead + AsyncWrite, -{ - type Item = ::Item; - type Error = ::Error; - - fn poll(&mut self) -> futures::Poll, Self::Error> { - self.io.as_framed().poll() - } -} - -impl futures::Sink for Connect -where - Codec: Encoder + Decoder, - Io: AsyncRead + AsyncWrite, -{ - type SinkItem = ::Item; - type SinkError = ::Error; - - fn start_send( - &mut self, - item: Self::SinkItem, - ) -> futures::StartSend { - self.io.as_framed().start_send(item) - } - - fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> { - self.io.as_framed().poll_complete() - } - - fn close(&mut self) -> futures::Poll<(), Self::SinkError> { - self.io.as_framed().close() - } -} - pub struct ConnectResult { - pub(crate) state: Cell, + pub(crate) state: St, pub(crate) framed: Framed, pub(crate) rx: mpsc::UnboundedReceiver::Item>>, pub(crate) sink: Sink<::Item>, @@ -145,6 +50,16 @@ impl ConnectResult { pub fn sink(&self) -> &Sink<::Item> { &self.sink } + + #[inline] + pub fn state(self, state: S) -> ConnectResult { + ConnectResult { + state, + framed: self.framed, + rx: self.rx, + sink: self.sink, + } + } } impl futures::Stream for ConnectResult diff --git a/actix-ioframe/src/dispatcher.rs b/actix-ioframe/src/dispatcher.rs index e27e62d7..02f0da51 100644 --- a/actix-ioframe/src/dispatcher.rs +++ b/actix-ioframe/src/dispatcher.rs @@ -13,6 +13,7 @@ use crate::cell::Cell; use crate::error::ServiceError; use crate::item::Item; use crate::sink::Sink; +use crate::state::State; type Request = Item; type Response = ::Item; @@ -37,8 +38,8 @@ where { service: S, sink: Sink<::Item>, - state: Cell, - dispatch_state: State, + state: State, + dispatch_state: FramedState, framed: Framed, rx: Option::Item>>>, inner: Cell::Item, S::Error>>, @@ -56,7 +57,7 @@ where { pub(crate) fn new>( framed: Framed, - state: Cell, + state: State, service: F, rx: mpsc::UnboundedReceiver::Item>>, sink: Sink<::Item>, @@ -67,7 +68,7 @@ where sink, rx: Some(rx), service: service.into_service(), - dispatch_state: State::Processing, + dispatch_state: FramedState::Processing, inner: Cell::new(FramedDispatcherInner { buf: VecDeque::new(), task: AtomicTask::new(), @@ -76,7 +77,7 @@ where } } -enum State { +enum FramedState { Processing, Error(ServiceError), FramedError(ServiceError), @@ -84,22 +85,22 @@ enum State { Stopping, } -impl State { +impl FramedState { fn stop(&mut self, tx: Option>) { match self { - State::FlushAndStop(ref mut vec) => { + FramedState::FlushAndStop(ref mut vec) => { if let Some(tx) = tx { vec.push(tx) } } - State::Processing => { - *self = State::FlushAndStop(if let Some(tx) = tx { + FramedState::Processing => { + *self = FramedState::FlushAndStop(if let Some(tx) = tx { vec![tx] } else { Vec::new() }) } - State::Error(_) | State::FramedError(_) | State::Stopping => { + FramedState::Error(_) | FramedState::FramedError(_) | FramedState::Stopping => { if let Some(tx) = tx { let _ = tx.send(()); } @@ -131,12 +132,12 @@ where Ok(Async::Ready(Some(el))) => el, Err(err) => { self.dispatch_state = - State::FramedError(ServiceError::Decoder(err)); + FramedState::FramedError(ServiceError::Decoder(err)); return true; } Ok(Async::NotReady) => return false, Ok(Async::Ready(None)) => { - self.dispatch_state = State::Stopping; + self.dispatch_state = FramedState::Stopping; return true; } }; @@ -161,7 +162,7 @@ where } Ok(Async::NotReady) => return false, Err(err) => { - self.dispatch_state = State::Error(ServiceError::Service(err)); + self.dispatch_state = FramedState::Error(ServiceError::Service(err)); return true; } } @@ -180,13 +181,14 @@ where Ok(msg) => { if let Err(err) = self.framed.force_send(msg) { self.dispatch_state = - State::FramedError(ServiceError::Encoder(err)); + FramedState::FramedError(ServiceError::Encoder(err)); return true; } buf_empty = inner.buf.is_empty(); } Err(err) => { - self.dispatch_state = State::Error(ServiceError::Service(err)); + self.dispatch_state = + FramedState::Error(ServiceError::Service(err)); return true; } } @@ -197,7 +199,7 @@ where Ok(Async::Ready(Some(FramedMessage::Message(msg)))) => { if let Err(err) = self.framed.force_send(msg) { self.dispatch_state = - State::FramedError(ServiceError::Encoder(err)); + FramedState::FramedError(ServiceError::Encoder(err)); return true; } } @@ -231,7 +233,8 @@ where Ok(Async::NotReady) => break, Err(err) => { debug!("Error sending data: {:?}", err); - self.dispatch_state = State::FramedError(ServiceError::Encoder(err)); + self.dispatch_state = + FramedState::FramedError(ServiceError::Encoder(err)); return true; } Ok(Async::Ready(_)) => (), @@ -259,32 +262,32 @@ where type Error = ServiceError; fn poll(&mut self) -> Poll { - match mem::replace(&mut self.dispatch_state, State::Processing) { - State::Processing => { + match mem::replace(&mut self.dispatch_state, FramedState::Processing) { + FramedState::Processing => { if self.poll_read() || self.poll_write() { self.poll() } else { Ok(Async::NotReady) } } - State::Error(err) => { + FramedState::Error(err) => { if self.framed.is_write_buf_empty() || (self.poll_write() || self.framed.is_write_buf_empty()) { Err(err) } else { - self.dispatch_state = State::Error(err); + self.dispatch_state = FramedState::Error(err); Ok(Async::NotReady) } } - State::FlushAndStop(mut vec) => { + FramedState::FlushAndStop(mut vec) => { if !self.framed.is_write_buf_empty() { match self.framed.poll_complete() { Err(err) => { debug!("Error sending data: {:?}", err); } Ok(Async::NotReady) => { - self.dispatch_state = State::FlushAndStop(vec); + self.dispatch_state = FramedState::FlushAndStop(vec); return Ok(Async::NotReady); } Ok(Async::Ready(_)) => (), @@ -295,8 +298,8 @@ where } Ok(Async::Ready(())) } - State::FramedError(err) => Err(err), - State::Stopping => Ok(Async::Ready(())), + FramedState::FramedError(err) => Err(err), + FramedState::Stopping => Ok(Async::Ready(())), } } } diff --git a/actix-ioframe/src/item.rs b/actix-ioframe/src/item.rs index f326d0a9..b8d4ae31 100644 --- a/actix-ioframe/src/item.rs +++ b/actix-ioframe/src/item.rs @@ -1,13 +1,14 @@ +use std::cell::{Ref, RefMut}; use std::fmt; use std::ops::{Deref, DerefMut}; use actix_codec::{Decoder, Encoder}; -use crate::cell::Cell; use crate::sink::Sink; +use crate::state::State; -pub struct Item { - state: Cell, +pub struct Item { + state: State, sink: Sink<::Item>, item: ::Item, } @@ -17,7 +18,7 @@ where Codec: Encoder + Decoder, { pub(crate) fn new( - state: Cell, + state: State, sink: Sink<::Item>, item: ::Item, ) -> Self { @@ -25,12 +26,12 @@ where } #[inline] - pub fn state(&self) -> &St { + pub fn state(&self) -> Ref { self.state.get_ref() } #[inline] - pub fn state_mut(&mut self) -> &mut St { + pub fn state_mut(&mut self) -> RefMut { self.state.get_mut() } @@ -43,6 +44,17 @@ where pub fn into_inner(self) -> ::Item { self.item } + + #[inline] + pub fn into_parts( + self, + ) -> ( + State, + Sink<::Item>, + ::Item, + ) { + (self.state, self.sink, self.item) + } } impl Deref for Item diff --git a/actix-ioframe/src/lib.rs b/actix-ioframe/src/lib.rs index fabfaa8d..5e25be2c 100644 --- a/actix-ioframe/src/lib.rs +++ b/actix-ioframe/src/lib.rs @@ -5,9 +5,11 @@ mod error; mod item; mod service; mod sink; +mod state; pub use self::connect::{Connect, ConnectResult}; pub use self::error::ServiceError; pub use self::item::Item; pub use self::service::{Builder, NewServiceBuilder, ServiceBuilder}; pub use self::sink::Sink; +pub use self::state::State; diff --git a/actix-ioframe/src/service.rs b/actix-ioframe/src/service.rs index 8ca9a964..b31b88ef 100644 --- a/actix-ioframe/src/service.rs +++ b/actix-ioframe/src/service.rs @@ -9,6 +9,7 @@ use crate::connect::{Connect, ConnectResult}; use crate::dispatcher::FramedDispatcher; use crate::error::ServiceError; use crate::item::Item; +use crate::state::State; type RequestItem = Item; type ResponseItem = Option<::Item>; @@ -295,7 +296,7 @@ where match fut.poll()? { Async::Ready(res) => { self.inner = FramedServiceImplResponseInner::Handler( - handler.new_service(res.state.get_ref()), + handler.new_service(&res.state), Some(res), ); self.poll() @@ -309,7 +310,11 @@ where let res = res.take().unwrap(); self.inner = FramedServiceImplResponseInner::Dispatcher(FramedDispatcher::new( - res.framed, res.state, handler, res.rx, res.sink, + res.framed, + State::new(res.state), + handler, + res.rx, + res.sink, )); self.poll() } diff --git a/actix-ioframe/src/state.rs b/actix-ioframe/src/state.rs new file mode 100644 index 00000000..7349b864 --- /dev/null +++ b/actix-ioframe/src/state.rs @@ -0,0 +1,30 @@ +use std::cell::{Ref, RefCell, RefMut}; +use std::rc::Rc; + +/// Connection state +/// +/// Connection state is an arbitrary data attached to the each incoming message. +#[derive(Debug)] +pub struct State(Rc>); + +impl State { + pub(crate) fn new(st: T) -> Self { + State(Rc::new(RefCell::new(st))) + } + + #[inline] + pub fn get_ref(&self) -> Ref { + self.0.borrow() + } + + #[inline] + pub fn get_mut(&mut self) -> RefMut { + self.0.borrow_mut() + } +} + +impl Clone for State { + fn clone(&self) -> Self { + State(self.0.clone()) + } +}