use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_utils::oneshot; use actix_utils::task::LocalWaker; use actix_utils::time::LowResTimeService; use futures::future::{err, Either}; use futures::{future, Sink, Stream}; use fxhash::FxHashMap; use amqp_codec::protocol::{Begin, Close, End, Error, Frame}; use amqp_codec::{AmqpCodec, AmqpCodecError, AmqpFrame}; use crate::cell::{Cell, WeakCell}; use crate::errors::AmqpTransportError; use crate::hb::{Heartbeat, HeartbeatAction}; use crate::session::{Session, SessionInner}; use crate::Configuration; pub struct Connection { inner: Cell, framed: Framed>, hb: Heartbeat, } pub(crate) enum ChannelState { Opening(Option>, WeakCell), Established(Cell), Closing(Option>>), } impl ChannelState { fn is_opening(&self) -> bool { match self { ChannelState::Opening(_, _) => true, _ => false, } } } pub(crate) struct ConnectionInner { local: Configuration, remote: Configuration, write_queue: VecDeque, write_task: LocalWaker, sessions: slab::Slab, sessions_map: FxHashMap, error: Option, state: State, } #[derive(PartialEq)] enum State { Normal, Closing, RemoteClose, Drop, } impl Connection { pub fn new( framed: Framed>, local: Configuration, remote: Configuration, time: Option, ) -> Connection { Connection { framed, hb: Heartbeat::new( local.timeout().unwrap(), remote.timeout(), time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))), ), inner: Cell::new(ConnectionInner::new(local, remote)), } } pub(crate) fn new_server( framed: Framed>, inner: Cell, time: Option, ) -> Connection { let l_timeout = inner.get_ref().local.timeout().unwrap(); let r_timeout = inner.get_ref().remote.timeout(); Connection { framed, inner, hb: Heartbeat::new( l_timeout, r_timeout, time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))), ), } } /// Connection controller pub fn controller(&self) -> ConnectionController { ConnectionController(self.inner.clone()) } /// Get remote configuration pub fn remote_config(&self) -> &Configuration { &self.inner.get_ref().remote } /// Gracefully close connection pub fn close(&mut self) -> impl Future> { future::ok(()) } // TODO: implement /// Close connection with error pub fn close_with_error( &mut self, _err: Error, ) -> impl Future> { future::ok(()) } /// Opens the session pub fn open_session(&mut self) -> impl Future> { let cell = self.inner.downgrade(); let inner = self.inner.clone(); async move { let inner = inner.get_mut(); if let Some(ref e) = inner.error { Err(e.clone()) } else { let (tx, rx) = oneshot::channel(); let entry = inner.sessions.vacant_entry(); let token = entry.key(); if token >= inner.local.channel_max { Err(AmqpTransportError::TooManyChannels) } else { entry.insert(ChannelState::Opening(Some(tx), cell)); let begin = Begin { remote_channel: None, next_outgoing_id: 1, incoming_window: std::u32::MAX, outgoing_window: std::u32::MAX, handle_max: std::u32::MAX, offered_capabilities: None, desired_capabilities: None, properties: None, }; inner.post_frame(AmqpFrame::new(token as u16, begin.into())); rx.await.map_err(|_| AmqpTransportError::Disconnected) } } } } /// Get session by id. This method panics if session does not exists or in opening/closing state. pub(crate) fn get_session(&self, id: usize) -> Cell { if let Some(channel) = self.inner.get_ref().sessions.get(id) { if let ChannelState::Established(ref session) = channel { return session.clone(); } } panic!("Session not found: {}", id); } pub(crate) fn register_remote_session(&mut self, channel_id: u16, begin: &Begin) { trace!("remote session opened: {:?}", channel_id); let cell = self.inner.clone(); let inner = self.inner.get_mut(); let entry = inner.sessions.vacant_entry(); let token = entry.key(); let session = Cell::new(SessionInner::new( token, false, ConnectionController(cell), token as u16, begin.next_outgoing_id(), begin.incoming_window(), begin.outgoing_window(), )); entry.insert(ChannelState::Established(session)); inner.sessions_map.insert(channel_id, token); let begin = Begin { remote_channel: Some(channel_id), next_outgoing_id: 1, incoming_window: std::u32::MAX, outgoing_window: begin.incoming_window(), handle_max: std::u32::MAX, offered_capabilities: None, desired_capabilities: None, properties: None, }; inner.post_frame(AmqpFrame::new(token as u16, begin.into())); } pub(crate) fn send_frame(&mut self, frame: AmqpFrame) { self.inner.get_mut().post_frame(frame) } pub(crate) fn register_write_task(&self, cx: &mut Context) { self.inner.write_task.register(cx.waker()); } pub(crate) fn poll_outgoing(&mut self, cx: &mut Context) -> Poll> { let inner = self.inner.get_mut(); let mut update = false; loop { while !self.framed.is_write_buf_full() { if let Some(frame) = inner.pop_next_frame() { trace!("outgoing: {:#?}", frame); update = true; if let Err(e) = self.framed.write(frame) { inner.set_error(e.clone().into()); return Poll::Ready(Err(e)); } } else { break; } } if !self.framed.is_write_buf_empty() { match self.framed.flush(cx) { Poll::Pending => break, Poll::Ready(Err(e)) => { trace!("error sending data: {}", e); inner.set_error(e.clone().into()); return Poll::Ready(Err(e)); } Poll::Ready(_) => (), } } else { break; } } self.hb.update_remote(update); if inner.state == State::Drop { Poll::Ready(Ok(())) } else if inner.state == State::RemoteClose && inner.write_queue.is_empty() && self.framed.is_write_buf_empty() { Poll::Ready(Ok(())) } else { Poll::Pending } } pub(crate) fn poll_incoming( &mut self, cx: &mut Context, ) -> Poll>> { let inner = self.inner.get_mut(); let mut update = false; loop { match Pin::new(&mut self.framed).poll_next(cx) { Poll::Ready(Some(Ok(frame))) => { trace!("incoming: {:#?}", frame); update = true; if let Frame::Empty = frame.performative() { self.hb.update_local(update); continue; } // handle connection close if let Frame::Close(ref close) = frame.performative() { inner.set_error(AmqpTransportError::Closed(close.error.clone())); if inner.state == State::Closing { inner.sessions.clear(); return Poll::Ready(None); } else { let close = Close { error: None }; inner.post_frame(AmqpFrame::new(0, close.into())); inner.state = State::RemoteClose; } } if inner.error.is_some() { error!("connection closed but new framed is received: {:?}", frame); return Poll::Ready(None); } // get local session id let channel_id = if let Some(token) = inner.sessions_map.get(&frame.channel_id()) { *token } else { // we dont have channel info, only Begin frame is allowed on new channel if let Frame::Begin(ref begin) = frame.performative() { if begin.remote_channel().is_some() { inner.complete_session_creation(frame.channel_id(), begin); } else { return Poll::Ready(Some(Ok(frame))); } } else { warn!("Unexpected frame: {:#?}", frame); } continue; }; // handle session frames if let Some(channel) = inner.sessions.get_mut(channel_id) { match channel { ChannelState::Opening(_, _) => { error!("Unexpected opening state: {}", channel_id); } ChannelState::Established(ref mut session) => { match frame.performative() { Frame::Attach(attach) => { let cell = session.clone(); if !session.get_mut().handle_attach(attach, cell) { return Poll::Ready(Some(Ok(frame))); } } Frame::Flow(_) | Frame::Detach(_) => { return Poll::Ready(Some(Ok(frame))); } Frame::End(remote_end) => { trace!("Remote session end: {}", frame.channel_id()); let end = End { error: None }; session.get_mut().set_error( AmqpTransportError::SessionEnded( remote_end.error.clone(), ), ); let id = session.get_mut().id(); inner.post_frame(AmqpFrame::new(id, end.into())); inner.sessions.remove(channel_id); inner.sessions_map.remove(&frame.channel_id()); } _ => session.get_mut().handle_frame(frame.into_parts().1), } } ChannelState::Closing(ref mut tx) => match frame.performative() { Frame::End(_) => { if let Some(tx) = tx.take() { let _ = tx.send(Ok(())); } inner.sessions.remove(channel_id); inner.sessions_map.remove(&frame.channel_id()); } frm => trace!("Got frame after initiated session end: {:?}", frm), }, } } else { error!("Can not find channel: {}", channel_id); continue; } } Poll::Ready(None) => { inner.set_error(AmqpTransportError::Disconnected); return Poll::Ready(None); } Poll::Pending => { self.hb.update_local(update); break; } Poll::Ready(Some(Err(e))) => { trace!("error reading: {:?}", e); inner.set_error(e.clone().into()); return Poll::Ready(Some(Err(e.into()))); } } } Poll::Pending } } impl Drop for Connection { fn drop(&mut self) { self.inner .get_mut() .set_error(AmqpTransportError::Disconnected); } } impl Future for Connection { type Output = Result<(), AmqpCodecError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { // connection heartbeat match self.hb.poll(cx) { Ok(act) => match act { HeartbeatAction::None => (), HeartbeatAction::Close => { self.inner.get_mut().set_error(AmqpTransportError::Timeout); return Poll::Ready(Ok(())); } HeartbeatAction::Heartbeat => { self.inner .get_mut() .write_queue .push_back(AmqpFrame::new(0, Frame::Empty)); } }, Err(e) => { self.inner.get_mut().set_error(e); return Poll::Ready(Ok(())); } } loop { match self.poll_incoming(cx) { Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Ready(Some(Ok(frame))) => { if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) { if let ChannelState::Established(ref session) = channel { session.get_mut().handle_frame(frame.into_parts().1); continue; } } warn!("Unexpected frame: {:?}", frame); } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Pending => break, } } let _ = self.poll_outgoing(cx)?; self.register_write_task(cx); match self.poll_incoming(cx) { Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Ready(Some(Ok(frame))) => { if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) { if let ChannelState::Established(ref session) = channel { session.get_mut().handle_frame(frame.into_parts().1); return Poll::Pending; } } warn!("Unexpected frame: {:?}", frame); } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Pending => (), } Poll::Pending } } #[derive(Clone)] pub struct ConnectionController(pub(crate) Cell); impl ConnectionController { pub(crate) fn new(local: Configuration) -> ConnectionController { ConnectionController(Cell::new(ConnectionInner { local, remote: Configuration::default(), write_queue: VecDeque::new(), write_task: LocalWaker::new(), sessions: slab::Slab::with_capacity(8), sessions_map: FxHashMap::default(), error: None, state: State::Normal, })) } pub(crate) fn set_remote(&mut self, remote: Configuration) { self.0.get_mut().remote = remote; } #[inline] /// Get remote connection configuration pub fn remote_config(&self) -> &Configuration { &self.0.get_ref().remote } #[inline] /// Drop connection pub fn drop_connection(&mut self) { let inner = self.0.get_mut(); inner.state = State::Drop; inner.write_task.wake() } pub(crate) fn post_frame(&mut self, frame: AmqpFrame) { self.0.get_mut().post_frame(frame) } pub(crate) fn drop_session_copy(&mut self, _id: usize) {} } impl ConnectionInner { pub(crate) fn new(local: Configuration, remote: Configuration) -> ConnectionInner { ConnectionInner { local, remote, write_queue: VecDeque::new(), write_task: LocalWaker::new(), sessions: slab::Slab::with_capacity(8), sessions_map: FxHashMap::default(), error: None, state: State::Normal, } } fn set_error(&mut self, err: AmqpTransportError) { for (_, channel) in self.sessions.iter_mut() { match channel { ChannelState::Opening(_, _) | ChannelState::Closing(_) => (), ChannelState::Established(ref mut ses) => { ses.get_mut().set_error(err.clone()); } } } self.sessions.clear(); self.sessions_map.clear(); self.error = Some(err); } fn pop_next_frame(&mut self) -> Option { self.write_queue.pop_front() } fn post_frame(&mut self, frame: AmqpFrame) { // trace!("POST-FRAME: {:#?}", frame.performative()); self.write_queue.push_back(frame); self.write_task.wake(); } fn complete_session_creation(&mut self, channel_id: u16, begin: &Begin) { trace!( "session opened: {:?} {:?}", channel_id, begin.remote_channel() ); let id = begin.remote_channel().unwrap() as usize; if let Some(channel) = self.sessions.get_mut(id) { if channel.is_opening() { if let ChannelState::Opening(tx, cell) = channel { let cell = cell.upgrade().unwrap(); let session = Cell::new(SessionInner::new( id, true, ConnectionController(cell), channel_id, begin.next_outgoing_id(), begin.incoming_window(), begin.outgoing_window(), )); self.sessions_map.insert(channel_id, id); if tx .take() .unwrap() .send(Session::new(session.clone())) .is_err() { // todo: send end session } *channel = ChannelState::Established(session) } } else { // send error response } } else { // todo: rogue begin right now - do nothing. in future might indicate incoming attach } } }