From 617a0b8338e61b1a3625f0cc7fa4f543cb23d701 Mon Sep 17 00:00:00 2001 From: Jomar Milan Date: Sat, 20 Jun 2026 10:46:02 -0700 Subject: Use separate function to handle play messages Also remove Arc from the HashMap --- src/play.rs | 276 ++++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 167 insertions(+), 109 deletions(-) (limited to 'src/play.rs') diff --git a/src/play.rs b/src/play.rs index 2f364c0..998ac29 100644 --- a/src/play.rs +++ b/src/play.rs @@ -1,13 +1,11 @@ use crate::AppState; -use crate::session::{HandObject, PlayerColor, Session}; +use crate::session::{HandObject, PlayerColor}; use axum::extract::ws::{Message, Utf8Bytes, WebSocket}; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::sync::{Arc, Mutex, RwLock, Weak}; -use tokio::sync::broadcast::Receiver; -use tokio::sync::broadcast::error::RecvError; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; +use std::mem; +use std::sync::{Arc, Mutex}; +use tokio::sync::{broadcast, mpsc, oneshot}; #[derive(Deserialize)] enum IncomingPlayMessage { @@ -20,7 +18,6 @@ enum IncomingPlayMessage { enum OutgoingPlayMessage { Initialize { colors: Vec }, Hand(Vec), - // TODO: include error details Error, } @@ -29,10 +26,49 @@ pub enum PlayUpdate { HandUpdate([Vec; PlayerColor::COUNT]), } +// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code) +#[allow(dead_code)] +pub enum Error { + BadJson(serde_json::Error), + Closed, + InvalidSession(String), + InvalidColor, +} + +struct PlayState { + session: Option, + color: PlayerColor, + update_cancel_tx: oneshot::Sender<()>, +} +impl PlayState { + pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self { + Self { + session: Default::default(), + color: Default::default(), + update_cancel_tx, + } + } +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Self::BadJson(value) + } +} + +impl From> for Error { + fn from(_: mpsc::error::SendError) -> Self { + Self::Closed + } +} + pub async fn handle_play(socket: WebSocket, app_state: Arc) { let (mut sender, mut receiver) = socket.split(); let (sender_tx, mut sender_rx) = mpsc::channel(2); + let (update_cancel_tx, _) = oneshot::channel(); + let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx))); + let mut send_task = tokio::spawn(async move { while let Some(message) = sender_rx.recv().await { let serialized = match serde_json::to_string(&message) { @@ -53,101 +89,42 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc) { }); let mut recv_task = { - let sender_tx = sender_tx.clone(); tokio::spawn(async move { - let mut player_session = None; - let player_color = Arc::new(RwLock::new(PlayerColor::Grey)); - while let Some(msg) = receiver.next().await { let Ok(Message::Text(text)) = msg else { continue; }; match serde_json::from_str(text.as_str()) { - Ok(IncomingPlayMessage::Initialize { id }) => { - let session = { - let sessions = app_state.sessions.read().unwrap(); - sessions - .get(&id) - .map(Arc::clone) - .ok_or("Session did not exist") - }; - - match session { - Ok(session) => { - let (colors, update_rx) = { - let session = session.lock().unwrap(); - - let colors: Vec = session.seats.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) - .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) - .map(|color| String::from(color.as_ref())) - .collect(); - let update_rx = session.update_tx.subscribe(); - - (colors, update_rx) - }; - - player_session = Some(Arc::downgrade(&session)); - { - let sender_tx = sender_tx.clone(); - let player_session = Arc::downgrade(&session); - let player_color = player_color.clone(); - - tokio::spawn(async move { - handle_update( - update_rx, - sender_tx, - player_session, - player_color, - ) - .await - }); - } - - let response = OutgoingPlayMessage::Initialize { colors }; - if sender_tx.send(response).await.is_err() { - break; - } + Ok(msg) => { + let result = + handle_play_message(msg, sender_tx.clone(), &state, &app_state).await; + match result { + Ok(_) => (), + Err(Error::Closed) => { + eprintln!("Failed to send play message as the channel closed."); + break; } - Err(err) => { - eprintln!("Failed to access session: {}", err); - let response = OutgoingPlayMessage::Error; - if sender_tx.send(response).await.is_err() { + Err(_) => { + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!( + "Failed to send play message as the channel closed: {}", + err + ); break; } } } } - Ok(IncomingPlayMessage::Color(color)) => { - let Some(session) = - player_session.clone().and_then(|session| session.upgrade()) - else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - let Ok(color) = PlayerColor::try_from(color.as_str()) else { - let _ = sender_tx.send(OutgoingPlayMessage::Error).await; - break - }; - - let hand = session.lock().unwrap().seats[&color].clone(); - *player_color.write().unwrap() = color; - if sender_tx - .send(OutgoingPlayMessage::Hand(hand)) - .await - .is_err() - { + Err(_) => { + // TODO: include error details + let result = sender_tx.send(OutgoingPlayMessage::Error).await; + if let Err(err) = result { + eprintln!("Failed to send play message as the channel closed: {}", err); break; } } - Err(err) => { - eprintln!( - "Encountered an error while handling a message from a player: {}", - err - ); - break; - } } } }) @@ -159,33 +136,114 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc) { } } -async fn handle_update( - mut update_rx: Receiver, - sender_tx: Sender, - _player_session: Weak>, - player_color: Arc>, -) { - loop { - match update_rx.recv().await { - Ok(PlayUpdate::HandUpdate(hands)) => { - let colors: Vec = hands.iter().enumerate() - .filter(|(_, hand)| hand.len() > 0) +async fn handle_play_message( + message: IncomingPlayMessage, + sender_tx: mpsc::Sender, + state: &Arc>, + app_state: &Arc, +) -> Result<(), Error> { + match message { + IncomingPlayMessage::Initialize { id } => { + let (colors, update_rx) = { + let sessions = app_state.sessions.read().unwrap(); + let session = sessions.get(&id).map(|session| session.lock().unwrap()); + // The Option is unwrapped with let else here instead of ok_or and propagating + // because the error moves id which would be used later, and the compiler does not + // reason about propagation returning early. + let Some(session) = session else { + return Err(Error::InvalidSession(id)); + }; + + let colors: Vec = session + .seats + .iter() + .enumerate() + .filter(|(_, hand)| !hand.is_empty()) .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) .map(|color| String::from(color.as_ref())) .collect(); - let _ = sender_tx - .send(OutgoingPlayMessage::Initialize { - colors, - }) - .await; - let hand = { - let color = player_color.read().unwrap(); - hands[usize::from(&*color)].to_owned() - }; - let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + let update_rx = session.update_tx.subscribe(); + + (colors, update_rx) + }; + + let update_cancel_rx = { + let mut state = state.lock().unwrap(); + let (update_cancel_tx, update_cancel_rx) = oneshot::channel(); + let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(()); + state.session = Some(id); + + update_cancel_rx + }; + { + let sender_tx = sender_tx.clone(); + let state = state.clone(); + + tokio::spawn(async move { + handle_update(update_rx, update_cancel_rx, sender_tx, state).await; + }); } - Err(RecvError::Closed) => break, - Err(RecvError::Lagged(_)) => continue, + + sender_tx + .send(OutgoingPlayMessage::Initialize { colors }) + .await + .map_err(Error::from) + } + IncomingPlayMessage::Color(color) => { + let hand = { + let mut state = state.lock().unwrap(); + state.color = + PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?; + let name = state + .session + .as_ref() + .ok_or_else(|| Error::InvalidSession(String::default()))?; + let sessions = app_state.sessions.read().unwrap(); + let session = sessions + .get(name) + .ok_or(Error::InvalidSession(name.clone()))?; + + session.lock().unwrap().seats[&state.color].clone() + }; + + sender_tx + .send(OutgoingPlayMessage::Hand(hand)) + .await + .map_err(Error::from) + } + } +} + +async fn handle_update( + mut update_rx: broadcast::Receiver, + mut cancel_rx: oneshot::Receiver<()>, + sender_tx: mpsc::Sender, + state: Arc>, +) { + loop { + tokio::select! { + update = update_rx.recv() => match update { + Ok(PlayUpdate::HandUpdate(hands)) => { + let colors: Vec = hands.iter().enumerate() + .filter(|(_, hand)| !hand.is_empty()) + .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) + .map(|color| String::from(color.as_ref())) + .collect(); + let _ = sender_tx + .send(OutgoingPlayMessage::Initialize { + colors, + }) + .await; + let hand = { + let color = &state.lock().unwrap().color; + hands[color].to_owned() + }; + let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; + } + Err(broadcast::error::RecvError::Closed) => break, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }, + _ = &mut cancel_rx => break, } } } -- cgit v1.2.3