diff options
| author | Jomar Milan <jomarm@jomarm.com> | 2026-06-20 10:46:02 -0700 |
|---|---|---|
| committer | Jomar Milan <jomarm@jomarm.com> | 2026-06-20 10:46:02 -0700 |
| commit | 617a0b8338e61b1a3625f0cc7fa4f543cb23d701 (patch) | |
| tree | 10ac60ac155fe9d3c78dced1ccc739680d2ab059 | |
| parent | a83eb7452b5d6bded88dca083b59df833f6dbd2f (diff) | |
Use separate function to handle play messages
Also remove Arc from the HashMap
| -rw-r--r-- | src/main.rs | 20 | ||||
| -rw-r--r-- | src/play.rs | 276 | ||||
| -rw-r--r-- | src/session.rs | 6 |
3 files changed, 178 insertions, 124 deletions
diff --git a/src/main.rs b/src/main.rs index c05d3c6..b423f6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,16 +4,12 @@ //! access to players to access and manage the contents of their hands from without using the game //! screen by using a web browser. -#![warn( - missing_docs, - missing_debug_implementations -)] +#![warn(missing_docs, missing_debug_implementations)] mod play; mod session; mod template; -use std::array; use crate::play::handle_play; use crate::session::{HandObject, PlayerColor, Session}; use crate::template::{IndexTemplate, SessionTemplate}; @@ -24,6 +20,7 @@ use axum::response::{ErrorResponse, Html, IntoResponse, Redirect, Response}; use axum::routing::{any, get, put}; use axum::{Json, Router}; use rust_embed::Embed; +use std::array; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Mutex, RwLock}; @@ -32,15 +29,14 @@ use std::sync::{Arc, Mutex, RwLock}; #[folder = "assets/"] struct EmbedAsset; +#[derive(Default)] struct AppState { - sessions: RwLock<HashMap<String, Arc<Mutex<Session>>>>, + sessions: RwLock<HashMap<String, Mutex<Session>>>, } impl AppState { fn new() -> Self { - AppState { - sessions: RwLock::new(HashMap::new()), - } + Self::default() } } @@ -120,7 +116,7 @@ async fn create_session( let mut sessions = state.sessions.write().unwrap(); let session = Session::new(name); - sessions.insert(id, Arc::new(Mutex::new(session))); + sessions.insert(id, Mutex::new(session)); StatusCode::CREATED } @@ -130,7 +126,7 @@ async fn update_hands( State(state): State<Arc<AppState>>, Json(payload): Json<HashMap<String, Vec<HandObject>>>, ) -> StatusCode { - let mut sessions = state.sessions.write().unwrap(); + let sessions = state.sessions.read().unwrap(); let hand = array::from_fn(|i| { let color = PlayerColor::try_from(i); @@ -144,7 +140,7 @@ async fn update_hands( } }); - match sessions.get_mut(&id) { + match sessions.get(&id) { Some(session) => { let mut session = session.lock().unwrap(); diff --git a/src/play.rs b/src/play.rs index 2f364c0..998ac29 100644 --- a/src/play.rs +++ b/src/play.rs @@ -1,13 +1,11 @@ use crate::AppState; -use crate::session::{HandObject, PlayerColor, Session}; +use crate::session::{HandObject, PlayerColor}; use axum::extract::ws::{Message, Utf8Bytes, WebSocket}; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::sync::{Arc, Mutex, RwLock, Weak}; -use tokio::sync::broadcast::Receiver; -use tokio::sync::broadcast::error::RecvError; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; +use std::mem; +use std::sync::{Arc, Mutex}; +use tokio::sync::{broadcast, mpsc, oneshot}; #[derive(Deserialize)] enum IncomingPlayMessage { @@ -20,7 +18,6 @@ enum IncomingPlayMessage { enum OutgoingPlayMessage { Initialize { colors: Vec<String> }, Hand(Vec<HandObject>), - // TODO: include error details Error, } @@ -29,10 +26,49 @@ pub enum PlayUpdate { HandUpdate([Vec<HandObject>; PlayerColor::COUNT]), } +// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code) +#[allow(dead_code)] +pub enum Error { + BadJson(serde_json::Error), + Closed, + InvalidSession(String), + InvalidColor, +} + +struct PlayState { + session: Option<String>, + color: PlayerColor, + update_cancel_tx: oneshot::Sender<()>, +} +impl PlayState { + pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self { + Self { + session: Default::default(), + color: Default::default(), + update_cancel_tx, + } + } +} + +impl From<serde_json::Error> for Error { + fn from(value: serde_json::Error) -> Self { + Self::BadJson(value) + } +} + +impl From<mpsc::error::SendError<OutgoingPlayMessage>> for Error { + fn from(_: mpsc::error::SendError<OutgoingPlayMessage>) -> Self { + Self::Closed + } +} + pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) { let (mut sender, mut receiver) = socket.split(); let (sender_tx, mut sender_rx) = mpsc::channel(2); + let (update_cancel_tx, _) = oneshot::channel(); + let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx))); + let mut send_task = tokio::spawn(async move { while let Some(message) = sender_rx.recv().await { let serialized = match serde_json::to_string(&message) { @@ -53,101 +89,42 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) { }); let mut recv_task = { - let sender_tx = sender_tx.clone(); tokio::spawn(async move { - let mut player_session = None; - let player_color = Arc::new(RwLock::new(PlayerColor::Grey)); - while let Some(msg) = receiver.next().await { let Ok(Message::Text(text)) = msg else { continue; }; match serde_json::from_str(text.as_str()) { - Ok(IncomingPlayMessage::Initialize { id }) => { - let session = { - let sessions = app_state.sessions.read().unwrap(); - sessions - .get(&id) - .map(Arc::clone) - .ok_or("Session did not exist") - }; - - match session { - Ok(session) => { - let (colors, update_rx) = { - let session = session.lock().unwrap(); - - let colors: Vec<String> = session.seats.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) - .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) - .map(|color| String::from(color.as_ref())) - .collect(); - let update_rx = session.update_tx.subscribe(); - - (colors, update_rx) - }; - - player_session = Some(Arc::downgrade(&session)); - { - let sender_tx = sender_tx.clone(); - let player_session = Arc::downgrade(&session); - let player_color = player_color.clone(); - - tokio::spawn(async move { - handle_update( - update_rx, - sender_tx, - player_session, - player_color, - ) - .await - }); - } - - let response = OutgoingPlayMessage::Initialize { colors }; - if sender_tx.send(response).await.is_err() { - break; - } + Ok(msg) => { + let result = + handle_play_message(msg, sender_tx.clone(), &state, &app_state).await; + match result { + Ok(_) => (), + Err(Error::Closed) => { + eprintln!("Failed to send play message as the channel closed."); + break; } - Err(err) => { - eprintln!("Failed to access session: {}", err); - let response = OutgoingPlayMessage::Error; - if sender_tx.send(response).await.is_err() { + Err(_) => { + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!( + "Failed to send play message as the channel closed: {}", + err + ); break; } } } } - Ok(IncomingPlayMessage::Color(color)) => { - let Some(session) = - player_session.clone().and_then(|session| session.upgrade()) - else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - let Ok(color) = PlayerColor::try_from(color.as_str()) else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - - let hand = session.lock().unwrap().seats[&color].clone(); - *player_color.write().unwrap() = color; - if sender_tx - .send(OutgoingPlayMessage::Hand(hand)) - .await - .is_err() - { + Err(_) => { + // TODO: include error details + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!("Failed to send play message as the channel closed: {}", err); break; } } - Err(err) => { - eprintln!( - "Encountered an error while handling a message from a player: {}", - err - ); - break; - } } } }) @@ -159,33 +136,114 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) { } } -async fn handle_update( - mut update_rx: Receiver<PlayUpdate>, - sender_tx: Sender<OutgoingPlayMessage>, - _player_session: Weak<Mutex<Session>>, - player_color: Arc<RwLock<PlayerColor>>, -) { - loop { - match update_rx.recv().await { - Ok(PlayUpdate::HandUpdate(hands)) => { - let colors: Vec<String> = hands.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) +async fn handle_play_message( + message: IncomingPlayMessage, + sender_tx: mpsc::Sender<OutgoingPlayMessage>, + state: &Arc<Mutex<PlayState>>, + app_state: &Arc<AppState>, +) -> Result<(), Error> { + match message { + IncomingPlayMessage::Initialize { id } => { + let (colors, update_rx) = { + let sessions = app_state.sessions.read().unwrap(); + let session = sessions.get(&id).map(|session| session.lock().unwrap()); + // The Option is unwrapped with let else here instead of ok_or and propagating + // because the error moves id which would be used later, and the compiler does not + // reason about propagation returning early. + let Some(session) = session else { + return Err(Error::InvalidSession(id)); + }; + + let colors: Vec<String> = session + .seats + .iter() + .enumerate() + .filter(|(_, hand)| !hand.is_empty()) .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) .map(|color| String::from(color.as_ref())) .collect(); - let _ = sender_tx - .send(OutgoingPlayMessage::Initialize { - colors, - }) - .await; - let hand = { - let color = player_color.read().unwrap(); - hands[usize::from(&*color)].to_owned() - }; - let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + let update_rx = session.update_tx.subscribe(); + + (colors, update_rx) + }; + + let update_cancel_rx = { + let mut state = state.lock().unwrap(); + let (update_cancel_tx, update_cancel_rx) = oneshot::channel(); + let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(()); + state.session = Some(id); + + update_cancel_rx + }; + { + let sender_tx = sender_tx.clone(); + let state = state.clone(); + + tokio::spawn(async move { + handle_update(update_rx, update_cancel_rx, sender_tx, state).await; + }); } - Err(RecvError::Closed) => break, - Err(RecvError::Lagged(_)) => continue, + + sender_tx + .send(OutgoingPlayMessage::Initialize { colors }) + .await + .map_err(Error::from) + } + IncomingPlayMessage::Color(color) => { + let hand = { + let mut state = state.lock().unwrap(); + state.color = + PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?; + let name = state + .session + .as_ref() + .ok_or_else(|| Error::InvalidSession(String::default()))?; + let sessions = app_state.sessions.read().unwrap(); + let session = sessions + .get(name) + .ok_or(Error::InvalidSession(name.clone()))?; + + session.lock().unwrap().seats[&state.color].clone() + }; + + sender_tx + .send(OutgoingPlayMessage::Hand(hand)) + .await + .map_err(Error::from) + } + } +} + +async fn handle_update( + mut update_rx: broadcast::Receiver<PlayUpdate>, + mut cancel_rx: oneshot::Receiver<()>, + sender_tx: mpsc::Sender<OutgoingPlayMessage>, + state: Arc<Mutex<PlayState>>, +) { + loop { + tokio::select! { + update = update_rx.recv() => match update { + Ok(PlayUpdate::HandUpdate(hands)) => { + let colors: Vec<String> = hands.iter().enumerate() + .filter(|(_, hand)| !hand.is_empty()) + .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) + .map(|color| String::from(color.as_ref())) + .collect(); + let _ = sender_tx + .send(OutgoingPlayMessage::Initialize { + colors, + }) + .await; + let hand = { + let color = &state.lock().unwrap().color; + hands[color].to_owned() + }; + let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + } + Err(broadcast::error::RecvError::Closed) => break, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }, + _ = &mut cancel_rx => break, } } } diff --git a/src/session.rs b/src/session.rs index 096ae1a..f4c8fe3 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,6 @@ -use std::array; use crate::play::PlayUpdate; use serde::{Deserialize, Serialize}; +use std::array; use std::ops::Index; use tokio::sync::broadcast; @@ -153,7 +153,7 @@ impl TryFrom<&str> for PlayerColor { "Pink" => Ok(Self::Pink), "Grey" => Ok(Self::Grey), "Black" => Ok(Self::Black), - _ => Err(()) + _ => Err(()), } } } @@ -175,7 +175,7 @@ impl TryFrom<usize> for PlayerColor { 9 => Ok(Self::Pink), 10 => Ok(Self::Grey), 11 => Ok(Self::Black), - _ => Err(()) + _ => Err(()), } } } |
