1
0
mirror of https://github.com/actix/examples synced 2024-11-23 22:41:07 +01:00

reduce boilerplate on response channels

This commit is contained in:
Rob Ede 2022-07-12 02:33:54 +01:00
parent 25368e6b65
commit 6c39b4b138
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
4 changed files with 176 additions and 143 deletions

View File

@ -1,32 +0,0 @@
use tokio::sync::{mpsc, oneshot};
use crate::{ConnId, Msg, RoomId};
#[derive(Debug)]
pub enum Command {
Connect {
conn_tx: mpsc::UnboundedSender<Msg>,
res_tx: oneshot::Sender<ConnId>,
},
Disconnect {
conn: ConnId,
},
List {
res_tx: oneshot::Sender<Vec<RoomId>>,
},
Join {
conn: ConnId,
room: RoomId,
res_tx: oneshot::Sender<()>,
},
Message {
room: RoomId,
msg: Msg,
skip: ConnId,
res_tx: oneshot::Sender<()>,
},
}

View File

@ -5,16 +5,9 @@ use futures_util::{
future::{select, Either}, future::{select, Either},
StreamExt as _, StreamExt as _,
}; };
use tokio::{ use tokio::{pin, sync::mpsc, time::interval};
pin,
sync::{
mpsc::{self, UnboundedSender},
oneshot,
},
time::interval,
};
use crate::{Command, ConnId, RoomId}; use crate::{ChatServerHandle, ConnId};
/// How often heartbeat pings are sent /// How often heartbeat pings are sent
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -25,27 +18,20 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Echo text & binary messages received from the client, respond to ping messages, and monitor /// Echo text & binary messages received from the client, respond to ping messages, and monitor
/// connection health to detect network issues and free up resources. /// connection health to detect network issues and free up resources.
pub async fn chat_ws( pub async fn chat_ws(
server_tx: UnboundedSender<Command>, chat_server: ChatServerHandle,
mut session: actix_ws::Session, mut session: actix_ws::Session,
mut msg_stream: actix_ws::MessageStream, mut msg_stream: actix_ws::MessageStream,
) { ) {
log::info!("connected"); log::info!("connected");
let mut name = None; let mut name = None;
let mut room = "main".to_owned();
let mut last_heartbeat = Instant::now(); let mut last_heartbeat = Instant::now();
let mut interval = interval(HEARTBEAT_INTERVAL); let mut interval = interval(HEARTBEAT_INTERVAL);
let (conn_tx, mut conn_rx) = mpsc::unbounded_channel(); let (conn_tx, mut conn_rx) = mpsc::unbounded_channel();
let (res_tx, res_rx) = oneshot::channel();
// unwrap: chat server is not dropped before the HTTP server // unwrap: chat server is not dropped before the HTTP server
server_tx let conn_id = chat_server.connect(conn_tx).await;
.send(Command::Connect { conn_tx, res_tx })
.unwrap();
// unwrap: chat server does not drop our response channel
let conn_id = res_rx.await.unwrap();
let close_reason = loop { let close_reason = loop {
// most of the futures we process need to be stack-pinned to work with select() // most of the futures we process need to be stack-pinned to work with select()
@ -56,6 +42,7 @@ pub async fn chat_ws(
let msg_rx = conn_rx.recv(); let msg_rx = conn_rx.recv();
pin!(msg_rx); pin!(msg_rx);
// TODO: nested select is pretty gross for readability on the match
let messages = select(msg_stream.next(), msg_rx); let messages = select(msg_stream.next(), msg_rx);
pin!(messages); pin!(messages);
@ -76,14 +63,7 @@ pub async fn chat_ws(
} }
Message::Text(text) => { Message::Text(text) => {
process_text_msg( process_text_msg(&chat_server, &mut session, &text, conn_id, &mut name)
&server_tx,
&mut session,
&text,
conn_id,
&mut room,
&mut name,
)
.await; .await;
} }
@ -113,8 +93,10 @@ pub async fn chat_ws(
session.text(chat_msg).await.unwrap(); session.text(chat_msg).await.unwrap();
} }
// all connection's msg senders were dropped // all connection's message senders were dropped
Either::Left((Either::Right((None, _)), _)) => unreachable!(), Either::Left((Either::Right((None, _)), _)) => unreachable!(
"all connection message senders were dropped; chat server may have panicked"
),
// heartbeat internal tick // heartbeat internal tick
Either::Right((_inst, _)) => { Either::Right((_inst, _)) => {
@ -132,16 +114,17 @@ pub async fn chat_ws(
}; };
}; };
chat_server.disconnect(conn_id);
// attempt to close connection gracefully // attempt to close connection gracefully
let _ = session.close(close_reason).await; let _ = session.close(close_reason).await;
} }
async fn process_text_msg( async fn process_text_msg(
server_tx: &UnboundedSender<Command>, chat_server: &ChatServerHandle,
session: &mut actix_ws::Session, session: &mut actix_ws::Session,
text: &str, text: &str,
conn: ConnId, conn: ConnId,
room: &mut RoomId,
name: &mut Option<String>, name: &mut Option<String>,
) { ) {
// strip leading and trailing whitespace (spaces, newlines, etc.) // strip leading and trailing whitespace (spaces, newlines, etc.)
@ -154,14 +137,9 @@ async fn process_text_msg(
// unwrap: we have guaranteed non-zero string length already // unwrap: we have guaranteed non-zero string length already
match cmd_args.next().unwrap() { match cmd_args.next().unwrap() {
"/list" => { "/list" => {
// Send ListRooms message to chat server and wait for log::info!("conn {conn}: listing rooms");
// response
log::info!("List rooms");
let (res_tx, res_rx) = oneshot::channel(); let rooms = chat_server.list_rooms().await;
server_tx.send(Command::List { res_tx }).unwrap();
// unwrap: chat server does not drop our response channel
let rooms = res_rx.await.unwrap();
for room in rooms { for room in rooms {
session.text(room).await.unwrap(); session.text(room).await.unwrap();
@ -169,24 +147,14 @@ async fn process_text_msg(
} }
"/join" => match cmd_args.next() { "/join" => match cmd_args.next() {
Some(room_id) => { Some(room) => {
*room = room_id.to_owned(); log::info!("conn {conn}: joining room {room}");
let (res_tx, res_rx) = oneshot::channel(); chat_server.join_room(conn, room).await;
server_tx session.text(format!("joined {room}")).await.unwrap();
.send(Command::Join {
conn,
room: room.clone(),
res_tx,
})
.unwrap();
// unwrap: chat server does not drop our response channel
res_rx.await.unwrap();
session.text(format!("joined {room_id}")).await.unwrap();
} }
None => { None => {
session.text("!!! room name is required").await.unwrap(); session.text("!!! room name is required").await.unwrap();
} }
@ -194,6 +162,7 @@ async fn process_text_msg(
"/name" => match cmd_args.next() { "/name" => match cmd_args.next() {
Some(new_name) => { Some(new_name) => {
log::info!("conn {conn}: setting name to: {new_name}");
name.replace(new_name.to_owned()); name.replace(new_name.to_owned());
} }
None => { None => {
@ -215,20 +184,6 @@ async fn process_text_msg(
None => msg.to_owned(), None => msg.to_owned(),
}; };
let (res_tx, res_rx) = oneshot::channel(); chat_server.send_message(conn, msg).await
// send message to chat server
server_tx
.send(Command::Message {
msg,
room: room.clone(),
skip: conn,
res_tx,
})
// unwrap: chat server is not dropped before the HTTP server
.unwrap();
// unwrap: chat server does not drop our response channel
res_rx.await.unwrap();
} }
} }

View File

@ -5,17 +5,14 @@
use actix_files::NamedFile; use actix_files::NamedFile;
use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder}; use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder};
use tokio::{ use tokio::{
sync::mpsc::UnboundedSender,
task::{spawn, spawn_local}, task::{spawn, spawn_local},
try_join, try_join,
}; };
mod command;
mod handler; mod handler;
mod server; mod server;
pub use self::command::Command; pub use self::server::{ChatServer, ChatServerHandle};
pub use self::server::ChatServer;
/// Connection ID. /// Connection ID.
pub type ConnId = usize; pub type ConnId = usize;
@ -34,12 +31,16 @@ async fn index() -> impl Responder {
async fn chat_ws( async fn chat_ws(
req: HttpRequest, req: HttpRequest,
stream: web::Payload, stream: web::Payload,
server_tx: web::Data<UnboundedSender<Command>>, chat_server: web::Data<ChatServerHandle>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
// spawn websocket handler (and don't await it) so that the response is returned immediately // spawn websocket handler (and don't await it) so that the response is returned immediately
spawn_local(handler::chat_ws((**server_tx).clone(), session, msg_stream)); spawn_local(handler::chat_ws(
(**chat_server).clone(),
session,
msg_stream,
));
Ok(res) Ok(res)
} }

View File

@ -10,11 +10,44 @@ use std::{
}; };
use rand::{thread_rng, Rng as _}; use rand::{thread_rng, Rng as _};
use tokio::sync::mpsc; use tokio::sync::{mpsc, oneshot};
use crate::{Command, ConnId, Msg, RoomId}; use crate::{ConnId, Msg, RoomId};
/// A command received by the [`ChatServer`].
#[derive(Debug)]
enum Command {
Connect {
conn_tx: mpsc::UnboundedSender<Msg>,
res_tx: oneshot::Sender<ConnId>,
},
Disconnect {
conn: ConnId,
},
List {
res_tx: oneshot::Sender<Vec<RoomId>>,
},
Join {
conn: ConnId,
room: RoomId,
res_tx: oneshot::Sender<()>,
},
Message {
msg: Msg,
conn: ConnId,
res_tx: oneshot::Sender<()>,
},
}
/// A multi-room chat server. /// A multi-room chat server.
///
/// Contains the logic of how connections chat with each other plus room management.
///
/// Call and spawn [`run`](Self::run) to start processing commands.
#[derive(Debug)] #[derive(Debug)]
pub struct ChatServer { pub struct ChatServer {
/// Map of connection IDs to their message receivers. /// Map of connection IDs to their message receivers.
@ -27,41 +60,39 @@ pub struct ChatServer {
visitor_count: Arc<AtomicUsize>, visitor_count: Arc<AtomicUsize>,
/// Command receiver. /// Command receiver.
rx: mpsc::UnboundedReceiver<Command>, cmd_rx: mpsc::UnboundedReceiver<Command>,
} }
impl ChatServer { impl ChatServer {
pub fn new() -> (Self, mpsc::UnboundedSender<Command>) { pub fn new() -> (Self, ChatServerHandle) {
// create empty server // create empty server
let mut rooms = HashMap::with_capacity(4); let mut rooms = HashMap::with_capacity(4);
// create default room // create default room
rooms.insert("main".to_owned(), HashSet::new()); rooms.insert("main".to_owned(), HashSet::new());
let (tx, rx) = mpsc::unbounded_channel(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
( (
Self { Self {
sessions: HashMap::new(), sessions: HashMap::new(),
rooms, rooms,
visitor_count: Arc::new(AtomicUsize::new(0)), visitor_count: Arc::new(AtomicUsize::new(0)),
rx, cmd_rx,
}, },
tx, ChatServerHandle { cmd_tx },
) )
} }
}
impl ChatServer { /// Send message to users in a room.
/// Send message to all users in the room.
/// ///
/// `skip_id` is used to prevent messages send by a connection also being received by it. /// `skip` is used to prevent messages triggered by a connection also being received by it.
async fn send_message(&self, room: &str, msg: impl Into<String>, skip_id: ConnId) { async fn send_system_message(&self, room: &str, skip: ConnId, msg: impl Into<String>) {
if let Some(sessions) = self.rooms.get(room) { if let Some(sessions) = self.rooms.get(room) {
let msg = msg.into(); let msg = msg.into();
for conn_id in sessions { for conn_id in sessions {
if *conn_id != skip_id { if *conn_id != skip {
if let Some(tx) = self.sessions.get(conn_id) { if let Some(tx) = self.sessions.get(conn_id) {
// errors if client disconnected abruptly and hasn't been timed-out yet // errors if client disconnected abruptly and hasn't been timed-out yet
let _ = tx.send(msg.clone()); let _ = tx.send(msg.clone());
@ -71,14 +102,26 @@ impl ChatServer {
} }
} }
/// Handler for Connect message. /// Send message to all other users in current room.
/// ///
/// Register new session and assign unique id to this session /// `conn` is used to find current room and prevent messages sent by a connection also being
/// received by it.
async fn send_message(&self, conn: ConnId, msg: impl Into<String>) {
if let Some(room) = self
.rooms
.iter()
.find_map(|(room, participants)| participants.contains(&conn).then_some(room))
{
self.send_system_message(&room, conn, msg).await;
};
}
/// Register new session and assign unique ID to this session
async fn connect(&mut self, tx: mpsc::UnboundedSender<Msg>) -> ConnId { async fn connect(&mut self, tx: mpsc::UnboundedSender<Msg>) -> ConnId {
log::info!("Someone joined"); log::info!("Someone joined");
// notify all users in same room // notify all users in same room
self.send_message("main", "Someone joined", 0).await; self.send_system_message("main", 0, "Someone joined").await;
// register session with random connection ID // register session with random connection ID
let id = thread_rng().gen::<usize>(); let id = thread_rng().gen::<usize>();
@ -91,14 +134,14 @@ impl ChatServer {
.insert(id); .insert(id);
let count = self.visitor_count.fetch_add(1, Ordering::SeqCst); let count = self.visitor_count.fetch_add(1, Ordering::SeqCst);
self.send_message("main", format!("Total visitors {count}"), 0) self.send_system_message("main", 0, format!("Total visitors {count}"))
.await; .await;
// send id back // send id back
id id
} }
/// Handler for Disconnect message. /// Unregister connection from room map and broadcast disconnection message.
async fn disconnect(&mut self, conn_id: ConnId) { async fn disconnect(&mut self, conn_id: ConnId) {
println!("Someone disconnected"); println!("Someone disconnected");
@ -116,19 +159,14 @@ impl ChatServer {
// send message to other users // send message to other users
for room in rooms { for room in rooms {
self.send_message(&room, "Someone disconnected", 0).await; self.send_system_message(&room, 0, "Someone disconnected")
.await;
} }
} }
/// Handler for `ListRooms` message. /// Returns list of created room names.
fn list_rooms(&mut self) -> Vec<String> { fn list_rooms(&mut self) -> Vec<String> {
let mut rooms = Vec::new(); self.rooms.keys().cloned().collect()
for key in self.rooms.keys() {
rooms.push(key.to_owned())
}
rooms
} }
/// Join room, send disconnect message to old room send join message to new room. /// Join room, send disconnect message to old room send join message to new room.
@ -143,7 +181,8 @@ impl ChatServer {
} }
// send message to other users // send message to other users
for room in rooms { for room in rooms {
self.send_message(&room, "Someone disconnected", 0).await; self.send_system_message(&room, 0, "Someone disconnected")
.await;
} }
self.rooms self.rooms
@ -151,12 +190,13 @@ impl ChatServer {
.or_insert_with(HashSet::new) .or_insert_with(HashSet::new)
.insert(conn_id); .insert(conn_id);
self.send_message(&room, "Someone connected", conn_id).await; self.send_system_message(&room, conn_id, "Someone connected")
.await;
} }
pub async fn run(mut self) -> io::Result<()> { pub async fn run(mut self) -> io::Result<()> {
loop { loop {
let cmd = match self.rx.recv().await { let cmd = match self.cmd_rx.recv().await {
Some(cmd) => cmd, Some(cmd) => cmd,
None => break, None => break,
}; };
@ -180,13 +220,8 @@ impl ChatServer {
let _ = res_tx.send(()); let _ = res_tx.send(());
} }
Command::Message { Command::Message { conn, msg, res_tx } => {
room, self.send_message(conn, msg).await;
msg,
skip,
res_tx,
} => {
self.send_message(&room, msg, skip).await;
let _ = res_tx.send(()); let _ = res_tx.send(());
} }
} }
@ -195,3 +230,77 @@ impl ChatServer {
Ok(()) Ok(())
} }
} }
/// Handle and command sender for chat server.
///
/// Reduces boilerplate of setting up response channels in WebSocket handlers.
#[derive(Debug, Clone)]
pub struct ChatServerHandle {
cmd_tx: mpsc::UnboundedSender<Command>,
}
impl ChatServerHandle {
/// Register client message sender and obtain connection ID.
pub async fn connect(&self, conn_tx: mpsc::UnboundedSender<String>) -> ConnId {
let (res_tx, res_rx) = oneshot::channel();
// unwrap: chat server should not have been dropped
self.cmd_tx
.send(Command::Connect { conn_tx, res_tx })
.unwrap();
// unwrap: chat server does not drop out response channel
res_rx.await.unwrap()
}
/// List all created rooms.
pub async fn list_rooms(&self) -> Vec<String> {
let (res_tx, res_rx) = oneshot::channel();
// unwrap: chat server should not have been dropped
self.cmd_tx.send(Command::List { res_tx }).unwrap();
// unwrap: chat server does not drop our response channel
res_rx.await.unwrap()
}
/// Join `room`, creating it if it does not exist.
pub async fn join_room(&self, conn: ConnId, room: impl Into<String>) {
let (res_tx, res_rx) = oneshot::channel();
// unwrap: chat server should not have been dropped
self.cmd_tx
.send(Command::Join {
conn,
room: room.into(),
res_tx,
})
.unwrap();
// unwrap: chat server does not drop our response channel
res_rx.await.unwrap();
}
/// Broadcast message to current room.
pub async fn send_message(&self, conn: ConnId, msg: impl Into<String>) {
let (res_tx, res_rx) = oneshot::channel();
// unwrap: chat server should not have been dropped
self.cmd_tx
.send(Command::Message {
msg: msg.into(),
conn,
res_tx,
})
.unwrap();
// unwrap: chat server does not drop our response channel
res_rx.await.unwrap();
}
/// Unregister message sender and broadcast disconnection message to current room.
pub fn disconnect(&self, conn: ConnId) {
// unwrap: chat server should not have been dropped
self.cmd_tx.send(Command::Disconnect { conn }).unwrap();
}
}