From 617a0b8338e61b1a3625f0cc7fa4f543cb23d701 Mon Sep 17 00:00:00 2001 From: Jomar Milan Date: Sat, 20 Jun 2026 10:46:02 -0700 Subject: Use separate function to handle play messages Also remove Arc from the HashMap --- src/main.rs | 20 ++--- src/play.rs | 276 ++++++++++++++++++++++++++++++++++----------------------- src/session.rs | 6 +- 3 files changed, 178 insertions(+), 124 deletions(-) diff --git a/src/main.rs b/src/main.rs index c05d3c6..b423f6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,16 +4,12 @@ //! access to players to access and manage the contents of their hands from without using the game //! screen by using a web browser. -#![warn( - missing_docs, - missing_debug_implementations -)] +#![warn(missing_docs, missing_debug_implementations)] mod play; mod session; mod template; -use std::array; use crate::play::handle_play; use crate::session::{HandObject, PlayerColor, Session}; use crate::template::{IndexTemplate, SessionTemplate}; @@ -24,6 +20,7 @@ use axum::response::{ErrorResponse, Html, IntoResponse, Redirect, Response}; use axum::routing::{any, get, put}; use axum::{Json, Router}; use rust_embed::Embed; +use std::array; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Mutex, RwLock}; @@ -32,15 +29,14 @@ use std::sync::{Arc, Mutex, RwLock}; #[folder = "assets/"] struct EmbedAsset; +#[derive(Default)] struct AppState { - sessions: RwLock>>>, + sessions: RwLock>>, } impl AppState { fn new() -> Self { - AppState { - sessions: RwLock::new(HashMap::new()), - } + Self::default() } } @@ -120,7 +116,7 @@ async fn create_session( let mut sessions = state.sessions.write().unwrap(); let session = Session::new(name); - sessions.insert(id, Arc::new(Mutex::new(session))); + sessions.insert(id, Mutex::new(session)); StatusCode::CREATED } @@ -130,7 +126,7 @@ async fn update_hands( State(state): State>, Json(payload): Json>>, ) -> StatusCode { - let mut sessions = state.sessions.write().unwrap(); + let sessions = state.sessions.read().unwrap(); let hand = array::from_fn(|i| { let color = PlayerColor::try_from(i); @@ -144,7 +140,7 @@ async fn update_hands( } }); - match sessions.get_mut(&id) { + match sessions.get(&id) { Some(session) => { let mut session = session.lock().unwrap(); diff --git a/src/play.rs b/src/play.rs index 2f364c0..998ac29 100644 --- a/src/play.rs +++ b/src/play.rs @@ -1,13 +1,11 @@ use crate::AppState; -use crate::session::{HandObject, PlayerColor, Session}; +use crate::session::{HandObject, PlayerColor}; use axum::extract::ws::{Message, Utf8Bytes, WebSocket}; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::sync::{Arc, Mutex, RwLock, Weak}; -use tokio::sync::broadcast::Receiver; -use tokio::sync::broadcast::error::RecvError; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; +use std::mem; +use std::sync::{Arc, Mutex}; +use tokio::sync::{broadcast, mpsc, oneshot}; #[derive(Deserialize)] enum IncomingPlayMessage { @@ -20,7 +18,6 @@ enum IncomingPlayMessage { enum OutgoingPlayMessage { Initialize { colors: Vec }, Hand(Vec), - // TODO: include error details Error, } @@ -29,10 +26,49 @@ pub enum PlayUpdate { HandUpdate([Vec; PlayerColor::COUNT]), } +// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code) +#[allow(dead_code)] +pub enum Error { + BadJson(serde_json::Error), + Closed, + InvalidSession(String), + InvalidColor, +} + +struct PlayState { + session: Option, + color: PlayerColor, + update_cancel_tx: oneshot::Sender<()>, +} +impl PlayState { + pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self { + Self { + session: Default::default(), + color: Default::default(), + update_cancel_tx, + } + } +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Self::BadJson(value) + } +} + +impl From> for Error { + fn from(_: mpsc::error::SendError) -> Self { + Self::Closed + } +} + pub async fn handle_play(socket: WebSocket, app_state: Arc) { let (mut sender, mut receiver) = socket.split(); let (sender_tx, mut sender_rx) = mpsc::channel(2); + let (update_cancel_tx, _) = oneshot::channel(); + let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx))); + let mut send_task = tokio::spawn(async move { while let Some(message) = sender_rx.recv().await { let serialized = match serde_json::to_string(&message) { @@ -53,101 +89,42 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc) { }); let mut recv_task = { - let sender_tx = sender_tx.clone(); tokio::spawn(async move { - let mut player_session = None; - let player_color = Arc::new(RwLock::new(PlayerColor::Grey)); - while let Some(msg) = receiver.next().await { let Ok(Message::Text(text)) = msg else { continue; }; match serde_json::from_str(text.as_str()) { - Ok(IncomingPlayMessage::Initialize { id }) => { - let session = { - let sessions = app_state.sessions.read().unwrap(); - sessions - .get(&id) - .map(Arc::clone) - .ok_or("Session did not exist") - }; - - match session { - Ok(session) => { - let (colors, update_rx) = { - let session = session.lock().unwrap(); - - let colors: Vec = session.seats.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) - .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) - .map(|color| String::from(color.as_ref())) - .collect(); - let update_rx = session.update_tx.subscribe(); - - (colors, update_rx) - }; - - player_session = Some(Arc::downgrade(&session)); - { - let sender_tx = sender_tx.clone(); - let player_session = Arc::downgrade(&session); - let player_color = player_color.clone(); - - tokio::spawn(async move { - handle_update( - update_rx, - sender_tx, - player_session, - player_color, - ) - .await - }); - } - - let response = OutgoingPlayMessage::Initialize { colors }; - if sender_tx.send(response).await.is_err() { - break; - } + Ok(msg) => { + let result = + handle_play_message(msg, sender_tx.clone(), &state, &app_state).await; + match result { + Ok(_) => (), + Err(Error::Closed) => { + eprintln!("Failed to send play message as the channel closed."); + break; } - Err(err) => { - eprintln!("Failed to access session: {}", err); - let response = OutgoingPlayMessage::Error; - if sender_tx.send(response).await.is_err() { + Err(_) => { + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!( + "Failed to send play message as the channel closed: {}", + err + ); break; } } } } - Ok(IncomingPlayMessage::Color(color)) => { - let Some(session) = - player_session.clone().and_then(|session| session.upgrade()) - else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - let Ok(color) = PlayerColor::try_from(color.as_str()) else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - - let hand = session.lock().unwrap().seats[&color].clone(); - *player_color.write().unwrap() = color; - if sender_tx - .send(OutgoingPlayMessage::Hand(hand)) - .await - .is_err() - { + Err(_) => { + // TODO: include error details + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!("Failed to send play message as the channel closed: {}", err); break; } } - Err(err) => { - eprintln!( - "Encountered an error while handling a message from a player: {}", - err - ); - break; - } } } }) @@ -159,33 +136,114 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc) { } } -async fn handle_update( - mut update_rx: Receiver, - sender_tx: Sender, - _player_session: Weak>, - player_color: Arc>, -) { - loop { - match update_rx.recv().await { - Ok(PlayUpdate::HandUpdate(hands)) => { - let colors: Vec = hands.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) +async fn handle_play_message( + message: IncomingPlayMessage, + sender_tx: mpsc::Sender, + state: &Arc>, + app_state: &Arc, +) -> Result<(), Error> { + match message { + IncomingPlayMessage::Initialize { id } => { + let (colors, update_rx) = { + let sessions = app_state.sessions.read().unwrap(); + let session = sessions.get(&id).map(|session| session.lock().unwrap()); + // The Option is unwrapped with let else here instead of ok_or and propagating + // because the error moves id which would be used later, and the compiler does not + // reason about propagation returning early. + let Some(session) = session else { + return Err(Error::InvalidSession(id)); + }; + + let colors: Vec = session + .seats + .iter() + .enumerate() + .filter(|(_, hand)| !hand.is_empty()) .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) .map(|color| String::from(color.as_ref())) .collect(); - let _ = sender_tx - .send(OutgoingPlayMessage::Initialize { - colors, - }) - .await; - let hand = { - let color = player_color.read().unwrap(); - hands[usize::from(&*color)].to_owned() - }; - let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + let update_rx = session.update_tx.subscribe(); + + (colors, update_rx) + }; + + let update_cancel_rx = { + let mut state = state.lock().unwrap(); + let (update_cancel_tx, update_cancel_rx) = oneshot::channel(); + let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(()); + state.session = Some(id); + + update_cancel_rx + }; + { + let sender_tx = sender_tx.clone(); + let state = state.clone(); + + tokio::spawn(async move { + handle_update(update_rx, update_cancel_rx, sender_tx, state).await; + }); } - Err(RecvError::Closed) => break, - Err(RecvError::Lagged(_)) => continue, + + sender_tx + .send(OutgoingPlayMessage::Initialize { colors }) + .await + .map_err(Error::from) + } + IncomingPlayMessage::Color(color) => { + let hand = { + let mut state = state.lock().unwrap(); + state.color = + PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?; + let name = state + .session + .as_ref() + .ok_or_else(|| Error::InvalidSession(String::default()))?; + let sessions = app_state.sessions.read().unwrap(); + let session = sessions + .get(name) + .ok_or(Error::InvalidSession(name.clone()))?; + + session.lock().unwrap().seats[&state.color].clone() + }; + + sender_tx + .send(OutgoingPlayMessage::Hand(hand)) + .await + .map_err(Error::from) + } + } +} + +async fn handle_update( + mut update_rx: broadcast::Receiver, + mut cancel_rx: oneshot::Receiver<()>, + sender_tx: mpsc::Sender, + state: Arc>, +) { + loop { + tokio::select! { + update = update_rx.recv() => match update { + Ok(PlayUpdate::HandUpdate(hands)) => { + let colors: Vec = hands.iter().enumerate() + .filter(|(_, hand)| !hand.is_empty()) + .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) + .map(|color| String::from(color.as_ref())) + .collect(); + let _ = sender_tx + .send(OutgoingPlayMessage::Initialize { + colors, + }) + .await; + let hand = { + let color = &state.lock().unwrap().color; + hands[color].to_owned() + }; + let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + } + Err(broadcast::error::RecvError::Closed) => break, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }, + _ = &mut cancel_rx => break, } } } diff --git a/src/session.rs b/src/session.rs index 096ae1a..f4c8fe3 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,6 @@ -use std::array; use crate::play::PlayUpdate; use serde::{Deserialize, Serialize}; +use std::array; use std::ops::Index; use tokio::sync::broadcast; @@ -153,7 +153,7 @@ impl TryFrom<&str> for PlayerColor { "Pink" => Ok(Self::Pink), "Grey" => Ok(Self::Grey), "Black" => Ok(Self::Black), - _ => Err(()) + _ => Err(()), } } } @@ -175,7 +175,7 @@ impl TryFrom for PlayerColor { 9 => Ok(Self::Pink), 10 => Ok(Self::Grey), 11 => Ok(Self::Black), - _ => Err(()) + _ => Err(()), } } } -- cgit v1.2.3