From 6c39b4b13882628b9e3649d9a69b6f24d1a07aa2 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Tue, 12 Jul 2022 02:33:54 +0100 Subject: [PATCH] reduce boilerplate on response channels --- websockets/chat-actorless/src/command.rs | 32 ---- websockets/chat-actorless/src/handler.rs | 91 +++-------- websockets/chat-actorless/src/main.rs | 13 +- websockets/chat-actorless/src/server.rs | 183 ++++++++++++++++++----- 4 files changed, 176 insertions(+), 143 deletions(-) delete mode 100644 websockets/chat-actorless/src/command.rs diff --git a/websockets/chat-actorless/src/command.rs b/websockets/chat-actorless/src/command.rs deleted file mode 100644 index 5781f926..00000000 --- a/websockets/chat-actorless/src/command.rs +++ /dev/null @@ -1,32 +0,0 @@ -use tokio::sync::{mpsc, oneshot}; - -use crate::{ConnId, Msg, RoomId}; - -#[derive(Debug)] -pub enum Command { - Connect { - conn_tx: mpsc::UnboundedSender, - res_tx: oneshot::Sender, - }, - - Disconnect { - conn: ConnId, - }, - - List { - res_tx: oneshot::Sender>, - }, - - Join { - conn: ConnId, - room: RoomId, - res_tx: oneshot::Sender<()>, - }, - - Message { - room: RoomId, - msg: Msg, - skip: ConnId, - res_tx: oneshot::Sender<()>, - }, -} diff --git a/websockets/chat-actorless/src/handler.rs b/websockets/chat-actorless/src/handler.rs index f89ba8cb..d9e71047 100644 --- a/websockets/chat-actorless/src/handler.rs +++ b/websockets/chat-actorless/src/handler.rs @@ -5,16 +5,9 @@ use futures_util::{ future::{select, Either}, StreamExt as _, }; -use tokio::{ - pin, - sync::{ - mpsc::{self, UnboundedSender}, - oneshot, - }, - time::interval, -}; +use tokio::{pin, sync::mpsc, time::interval}; -use crate::{Command, ConnId, RoomId}; +use crate::{ChatServerHandle, ConnId}; /// How often heartbeat pings are sent const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -25,27 +18,20 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); /// Echo text & binary messages received from the client, respond to ping messages, and monitor /// connection health to detect network issues and free up resources. pub async fn chat_ws( - server_tx: UnboundedSender, + chat_server: ChatServerHandle, mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream, ) { log::info!("connected"); let mut name = None; - let mut room = "main".to_owned(); let mut last_heartbeat = Instant::now(); let mut interval = interval(HEARTBEAT_INTERVAL); let (conn_tx, mut conn_rx) = mpsc::unbounded_channel(); - let (res_tx, res_rx) = oneshot::channel(); // unwrap: chat server is not dropped before the HTTP server - server_tx - .send(Command::Connect { conn_tx, res_tx }) - .unwrap(); - - // unwrap: chat server does not drop our response channel - let conn_id = res_rx.await.unwrap(); + let conn_id = chat_server.connect(conn_tx).await; let close_reason = loop { // most of the futures we process need to be stack-pinned to work with select() @@ -56,6 +42,7 @@ pub async fn chat_ws( let msg_rx = conn_rx.recv(); pin!(msg_rx); + // TODO: nested select is pretty gross for readability on the match let messages = select(msg_stream.next(), msg_rx); pin!(messages); @@ -76,15 +63,8 @@ pub async fn chat_ws( } Message::Text(text) => { - process_text_msg( - &server_tx, - &mut session, - &text, - conn_id, - &mut room, - &mut name, - ) - .await; + process_text_msg(&chat_server, &mut session, &text, conn_id, &mut name) + .await; } Message::Binary(_bin) => { @@ -113,8 +93,10 @@ pub async fn chat_ws( session.text(chat_msg).await.unwrap(); } - // all connection's msg senders were dropped - Either::Left((Either::Right((None, _)), _)) => unreachable!(), + // all connection's message senders were dropped + Either::Left((Either::Right((None, _)), _)) => unreachable!( + "all connection message senders were dropped; chat server may have panicked" + ), // heartbeat internal tick Either::Right((_inst, _)) => { @@ -132,16 +114,17 @@ pub async fn chat_ws( }; }; + chat_server.disconnect(conn_id); + // attempt to close connection gracefully let _ = session.close(close_reason).await; } async fn process_text_msg( - server_tx: &UnboundedSender, + chat_server: &ChatServerHandle, session: &mut actix_ws::Session, text: &str, conn: ConnId, - room: &mut RoomId, name: &mut Option, ) { // strip leading and trailing whitespace (spaces, newlines, etc.) @@ -154,14 +137,9 @@ async fn process_text_msg( // unwrap: we have guaranteed non-zero string length already match cmd_args.next().unwrap() { "/list" => { - // Send ListRooms message to chat server and wait for - // response - log::info!("List rooms"); + log::info!("conn {conn}: listing rooms"); - let (res_tx, res_rx) = oneshot::channel(); - server_tx.send(Command::List { res_tx }).unwrap(); - // unwrap: chat server does not drop our response channel - let rooms = res_rx.await.unwrap(); + let rooms = chat_server.list_rooms().await; for room in rooms { session.text(room).await.unwrap(); @@ -169,24 +147,14 @@ async fn process_text_msg( } "/join" => match cmd_args.next() { - Some(room_id) => { - *room = room_id.to_owned(); + Some(room) => { + log::info!("conn {conn}: joining room {room}"); - let (res_tx, res_rx) = oneshot::channel(); + chat_server.join_room(conn, room).await; - server_tx - .send(Command::Join { - conn, - room: room.clone(), - res_tx, - }) - .unwrap(); - - // unwrap: chat server does not drop our response channel - res_rx.await.unwrap(); - - session.text(format!("joined {room_id}")).await.unwrap(); + session.text(format!("joined {room}")).await.unwrap(); } + None => { session.text("!!! room name is required").await.unwrap(); } @@ -194,6 +162,7 @@ async fn process_text_msg( "/name" => match cmd_args.next() { Some(new_name) => { + log::info!("conn {conn}: setting name to: {new_name}"); name.replace(new_name.to_owned()); } None => { @@ -215,20 +184,6 @@ async fn process_text_msg( None => msg.to_owned(), }; - let (res_tx, res_rx) = oneshot::channel(); - - // send message to chat server - server_tx - .send(Command::Message { - msg, - room: room.clone(), - skip: conn, - res_tx, - }) - // unwrap: chat server is not dropped before the HTTP server - .unwrap(); - - // unwrap: chat server does not drop our response channel - res_rx.await.unwrap(); + chat_server.send_message(conn, msg).await } } diff --git a/websockets/chat-actorless/src/main.rs b/websockets/chat-actorless/src/main.rs index 9342e2dd..b5890da9 100644 --- a/websockets/chat-actorless/src/main.rs +++ b/websockets/chat-actorless/src/main.rs @@ -5,17 +5,14 @@ use actix_files::NamedFile; use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder}; use tokio::{ - sync::mpsc::UnboundedSender, task::{spawn, spawn_local}, try_join, }; -mod command; mod handler; mod server; -pub use self::command::Command; -pub use self::server::ChatServer; +pub use self::server::{ChatServer, ChatServerHandle}; /// Connection ID. pub type ConnId = usize; @@ -34,12 +31,16 @@ async fn index() -> impl Responder { async fn chat_ws( req: HttpRequest, stream: web::Payload, - server_tx: web::Data>, + chat_server: web::Data, ) -> Result { let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; // spawn websocket handler (and don't await it) so that the response is returned immediately - spawn_local(handler::chat_ws((**server_tx).clone(), session, msg_stream)); + spawn_local(handler::chat_ws( + (**chat_server).clone(), + session, + msg_stream, + )); Ok(res) } diff --git a/websockets/chat-actorless/src/server.rs b/websockets/chat-actorless/src/server.rs index c7fbcc83..fd917251 100644 --- a/websockets/chat-actorless/src/server.rs +++ b/websockets/chat-actorless/src/server.rs @@ -10,11 +10,44 @@ use std::{ }; use rand::{thread_rng, Rng as _}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; -use crate::{Command, ConnId, Msg, RoomId}; +use crate::{ConnId, Msg, RoomId}; + +/// A command received by the [`ChatServer`]. +#[derive(Debug)] +enum Command { + Connect { + conn_tx: mpsc::UnboundedSender, + res_tx: oneshot::Sender, + }, + + Disconnect { + conn: ConnId, + }, + + List { + res_tx: oneshot::Sender>, + }, + + Join { + conn: ConnId, + room: RoomId, + res_tx: oneshot::Sender<()>, + }, + + Message { + msg: Msg, + conn: ConnId, + res_tx: oneshot::Sender<()>, + }, +} /// A multi-room chat server. +/// +/// Contains the logic of how connections chat with each other plus room management. +/// +/// Call and spawn [`run`](Self::run) to start processing commands. #[derive(Debug)] pub struct ChatServer { /// Map of connection IDs to their message receivers. @@ -27,41 +60,39 @@ pub struct ChatServer { visitor_count: Arc, /// Command receiver. - rx: mpsc::UnboundedReceiver, + cmd_rx: mpsc::UnboundedReceiver, } impl ChatServer { - pub fn new() -> (Self, mpsc::UnboundedSender) { + pub fn new() -> (Self, ChatServerHandle) { // create empty server let mut rooms = HashMap::with_capacity(4); // create default room rooms.insert("main".to_owned(), HashSet::new()); - let (tx, rx) = mpsc::unbounded_channel(); + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); ( Self { sessions: HashMap::new(), rooms, visitor_count: Arc::new(AtomicUsize::new(0)), - rx, + cmd_rx, }, - tx, + ChatServerHandle { cmd_tx }, ) } -} -impl ChatServer { - /// Send message to all users in the room. + /// Send message to users in a room. /// - /// `skip_id` is used to prevent messages send by a connection also being received by it. - async fn send_message(&self, room: &str, msg: impl Into, skip_id: ConnId) { + /// `skip` is used to prevent messages triggered by a connection also being received by it. + async fn send_system_message(&self, room: &str, skip: ConnId, msg: impl Into) { if let Some(sessions) = self.rooms.get(room) { let msg = msg.into(); for conn_id in sessions { - if *conn_id != skip_id { + if *conn_id != skip { if let Some(tx) = self.sessions.get(conn_id) { // errors if client disconnected abruptly and hasn't been timed-out yet let _ = tx.send(msg.clone()); @@ -71,14 +102,26 @@ impl ChatServer { } } - /// Handler for Connect message. + /// Send message to all other users in current room. /// - /// Register new session and assign unique id to this session + /// `conn` is used to find current room and prevent messages sent by a connection also being + /// received by it. + async fn send_message(&self, conn: ConnId, msg: impl Into) { + if let Some(room) = self + .rooms + .iter() + .find_map(|(room, participants)| participants.contains(&conn).then_some(room)) + { + self.send_system_message(&room, conn, msg).await; + }; + } + + /// Register new session and assign unique ID to this session async fn connect(&mut self, tx: mpsc::UnboundedSender) -> ConnId { log::info!("Someone joined"); // notify all users in same room - self.send_message("main", "Someone joined", 0).await; + self.send_system_message("main", 0, "Someone joined").await; // register session with random connection ID let id = thread_rng().gen::(); @@ -91,14 +134,14 @@ impl ChatServer { .insert(id); let count = self.visitor_count.fetch_add(1, Ordering::SeqCst); - self.send_message("main", format!("Total visitors {count}"), 0) + self.send_system_message("main", 0, format!("Total visitors {count}")) .await; // send id back id } - /// Handler for Disconnect message. + /// Unregister connection from room map and broadcast disconnection message. async fn disconnect(&mut self, conn_id: ConnId) { println!("Someone disconnected"); @@ -116,19 +159,14 @@ impl ChatServer { // send message to other users for room in rooms { - self.send_message(&room, "Someone disconnected", 0).await; + self.send_system_message(&room, 0, "Someone disconnected") + .await; } } - /// Handler for `ListRooms` message. + /// Returns list of created room names. fn list_rooms(&mut self) -> Vec { - let mut rooms = Vec::new(); - - for key in self.rooms.keys() { - rooms.push(key.to_owned()) - } - - rooms + self.rooms.keys().cloned().collect() } /// Join room, send disconnect message to old room send join message to new room. @@ -143,7 +181,8 @@ impl ChatServer { } // send message to other users for room in rooms { - self.send_message(&room, "Someone disconnected", 0).await; + self.send_system_message(&room, 0, "Someone disconnected") + .await; } self.rooms @@ -151,12 +190,13 @@ impl ChatServer { .or_insert_with(HashSet::new) .insert(conn_id); - self.send_message(&room, "Someone connected", conn_id).await; + self.send_system_message(&room, conn_id, "Someone connected") + .await; } pub async fn run(mut self) -> io::Result<()> { loop { - let cmd = match self.rx.recv().await { + let cmd = match self.cmd_rx.recv().await { Some(cmd) => cmd, None => break, }; @@ -180,13 +220,8 @@ impl ChatServer { let _ = res_tx.send(()); } - Command::Message { - room, - msg, - skip, - res_tx, - } => { - self.send_message(&room, msg, skip).await; + Command::Message { conn, msg, res_tx } => { + self.send_message(conn, msg).await; let _ = res_tx.send(()); } } @@ -195,3 +230,77 @@ impl ChatServer { Ok(()) } } + +/// Handle and command sender for chat server. +/// +/// Reduces boilerplate of setting up response channels in WebSocket handlers. +#[derive(Debug, Clone)] +pub struct ChatServerHandle { + cmd_tx: mpsc::UnboundedSender, +} + +impl ChatServerHandle { + /// Register client message sender and obtain connection ID. + pub async fn connect(&self, conn_tx: mpsc::UnboundedSender) -> ConnId { + let (res_tx, res_rx) = oneshot::channel(); + + // unwrap: chat server should not have been dropped + self.cmd_tx + .send(Command::Connect { conn_tx, res_tx }) + .unwrap(); + + // unwrap: chat server does not drop out response channel + res_rx.await.unwrap() + } + + /// List all created rooms. + pub async fn list_rooms(&self) -> Vec { + let (res_tx, res_rx) = oneshot::channel(); + + // unwrap: chat server should not have been dropped + self.cmd_tx.send(Command::List { res_tx }).unwrap(); + + // unwrap: chat server does not drop our response channel + res_rx.await.unwrap() + } + + /// Join `room`, creating it if it does not exist. + pub async fn join_room(&self, conn: ConnId, room: impl Into) { + let (res_tx, res_rx) = oneshot::channel(); + + // unwrap: chat server should not have been dropped + self.cmd_tx + .send(Command::Join { + conn, + room: room.into(), + res_tx, + }) + .unwrap(); + + // unwrap: chat server does not drop our response channel + res_rx.await.unwrap(); + } + + /// Broadcast message to current room. + pub async fn send_message(&self, conn: ConnId, msg: impl Into) { + let (res_tx, res_rx) = oneshot::channel(); + + // unwrap: chat server should not have been dropped + self.cmd_tx + .send(Command::Message { + msg: msg.into(), + conn, + res_tx, + }) + .unwrap(); + + // unwrap: chat server does not drop our response channel + res_rx.await.unwrap(); + } + + /// Unregister message sender and broadcast disconnection message to current room. + pub fn disconnect(&self, conn: ConnId) { + // unwrap: chat server should not have been dropped + self.cmd_tx.send(Command::Disconnect { conn }).unwrap(); + } +}