From bc7e613204a1dbc2f5b37761a6649658effe2483 Mon Sep 17 00:00:00 2001 From: Jomar Milan Date: Mon, 22 Jun 2026 19:45:08 -0700 Subject: Move crate::play to crate::app::socket --- src/app.rs | 2 + src/app/socket.rs | 249 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 9 +- src/play.rs | 244 ---------------------------------------------------- src/session.rs | 2 +- 5 files changed, 258 insertions(+), 248 deletions(-) create mode 100644 src/app/socket.rs delete mode 100644 src/play.rs (limited to 'src') diff --git a/src/app.rs b/src/app.rs index 43e2ac6..c04bac3 100644 --- a/src/app.rs +++ b/src/app.rs @@ -4,6 +4,8 @@ use crate::session::Session; use std::collections::HashMap; use std::sync::{Mutex, MutexGuard, RwLock}; +pub mod socket; + /// Provider of the state that Tabletop Ambulator needs to keep track of. /// /// Provides the app state, including runtime data such as the active game diff --git a/src/app/socket.rs b/src/app/socket.rs new file mode 100644 index 0000000..08e1db6 --- /dev/null +++ b/src/app/socket.rs @@ -0,0 +1,249 @@ +use crate::AppState; +use crate::session::{HandObject, PlayerColor}; +use axum::extract::ws::{Message, Utf8Bytes}; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use std::mem; +use std::sync::{Arc, Mutex}; +use tokio::sync::{broadcast, mpsc, oneshot}; + +// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code) +#[allow(dead_code)] +pub enum Error { + BadJson(serde_json::Error), + Closed, + InvalidSession(String), + InvalidColor, +} + +#[derive(Clone)] +pub enum PlayUpdate { + HandUpdate([Vec; PlayerColor::COUNT]), +} + +#[derive(Deserialize)] +enum IncomingMessage { + Initialize { id: String }, + Color(String), +} + +// TODO: Maybe derive Clone, reference interior vals +#[derive(Serialize)] +enum OutgoingMessage { + Initialize { colors: Vec }, + Hand(Vec), + Error, +} + +struct PlayState { + session: Option, + color: PlayerColor, + update_cancel_tx: oneshot::Sender<()>, +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Self::BadJson(value) + } +} + +impl From> for Error { + fn from(_: mpsc::error::SendError) -> Self { + Self::Closed + } +} + +impl PlayState { + pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self { + Self { + session: Default::default(), + color: Default::default(), + update_cancel_tx, + } + } +} + +pub async fn handle_play(mut sender: S, mut receiver: R, app_state: Arc) +where + S: Sink + Unpin + Send + 'static, + R: Stream> + Unpin + Send + 'static, +{ + let (sender_tx, mut sender_rx) = mpsc::channel(2); + + let (update_cancel_tx, _) = oneshot::channel(); + let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx))); + + let mut send_task = tokio::spawn(async move { + while let Some(message) = sender_rx.recv().await { + let serialized = match serde_json::to_string(&message) { + Ok(serialized) => serialized, + Err(err) => { + eprintln!("Failed to serialize outgoing websocket message: {}", err); + break; + } + }; + if let Err(err) = sender + .send(Message::Text(Utf8Bytes::from(serialized))) + .await + { + eprintln!("Failed to send serialized websocket message: {:?}", err); + break; + } + } + }); + + let mut recv_task = { + tokio::spawn(async move { + while let Some(msg) = receiver.next().await { + let Ok(Message::Text(text)) = msg else { + continue; + }; + + match serde_json::from_str(text.as_str()) { + Ok(msg) => { + let result = + handle_play_message(msg, sender_tx.clone(), &state, &app_state).await; + match result { + Ok(_) => (), + Err(Error::Closed) => { + eprintln!("Failed to send play message as the channel closed."); + break; + } + Err(_) => { + let result = sender_tx.send(OutgoingMessage::Error).await; + if let Err(err) = result { + eprintln!( + "Failed to send play message as the channel closed: {}", + err + ); + break; + } + } + } + } + Err(_) => { + // TODO: include error details + let result = sender_tx.send(OutgoingMessage::Error).await; + if let Err(err) = result { + eprintln!("Failed to send play message as the channel closed: {}", err); + break; + } + } + } + } + }) + }; + + tokio::select! { + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), + } +} + +async fn handle_play_message( + message: IncomingMessage, + sender_tx: mpsc::Sender, + state: &Arc>, + app_state: &Arc, +) -> Result<(), Error> { + match message { + IncomingMessage::Initialize { id } => { + let data_opt = app_state.with_session(id.as_str(), |session| { + let colors: Vec = session + .seats + .iter() + .enumerate() + .filter(|(_, hand)| !hand.is_empty()) + .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) + .map(|color| String::from(color.as_ref())) + .collect(); + let update_rx = session.update_tx.subscribe(); + + (colors, update_rx) + }); + // let else used instead of propagating Option::ok_or_else because compiler wouldn't + // know about early return when moving id + let Some((colors, update_rx)) = data_opt else { + return Err(Error::InvalidSession(id)); + }; + + let update_cancel_rx = { + let mut state = state.lock().unwrap(); + let (update_cancel_tx, update_cancel_rx) = oneshot::channel(); + let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(()); + state.session = Some(id); + + update_cancel_rx + }; + { + let sender_tx = sender_tx.clone(); + let state = state.clone(); + + tokio::spawn(async move { + handle_update(update_rx, update_cancel_rx, sender_tx, state).await; + }); + } + + sender_tx + .send(OutgoingMessage::Initialize { colors }) + .await + .map_err(Error::from) + } + IncomingMessage::Color(color) => { + let hand = { + let mut state = state.lock().unwrap(); + state.color = + PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?; + + let name = state + .session + .as_ref() + .ok_or_else(|| Error::InvalidSession(String::default()))?; + + app_state + .with_session(name.as_str(), |session| session.seats[&state.color].clone()) + .ok_or_else(|| Error::InvalidSession(name.clone()))? + }; + + sender_tx + .send(OutgoingMessage::Hand(hand)) + .await + .map_err(Error::from) + } + } +} + +async fn handle_update( + mut update_rx: broadcast::Receiver, + mut cancel_rx: oneshot::Receiver<()>, + sender_tx: mpsc::Sender, + state: Arc>, +) { + loop { + tokio::select! { + update = update_rx.recv() => match update { + Ok(PlayUpdate::HandUpdate(hands)) => { + let colors: Vec = hands.iter().enumerate() + .filter(|(_, hand)| !hand.is_empty()) + .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) + .map(|color| String::from(color.as_ref())) + .collect(); + let _ = sender_tx + .send(OutgoingMessage::Initialize { + colors, + }) + .await; + let hand = { + let color = &state.lock().unwrap().color; + hands[color].to_owned() + }; + let _ = sender_tx.send(OutgoingMessage::Hand(hand)).await; + } + Err(broadcast::error::RecvError::Closed) => break, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }, + _ = &mut cancel_rx => break, + } + } +} diff --git a/src/main.rs b/src/main.rs index 38bbbf2..a4e51c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,20 +7,20 @@ #![warn(missing_docs, missing_debug_implementations)] mod app; -mod play; mod session; mod template; use crate::app::AppState; -use crate::play::handle_play; use crate::session::{HandObject, PlayerColor}; use crate::template::{IndexTemplate, SessionTemplate}; +use app::socket::handle_play; use askama::Template; use axum::extract::{Path, Query, State, WebSocketUpgrade}; use axum::http::{StatusCode, header}; use axum::response::{ErrorResponse, Html, IntoResponse, Redirect, Response}; use axum::routing::{any, get, put}; use axum::{Json, Router}; +use futures_util::StreamExt; use rust_embed::Embed; use std::array; use std::collections::HashMap; @@ -132,5 +132,8 @@ async fn update_hands( } async fn upgrade_play(ws: WebSocketUpgrade, State(state): State>) -> Response { - ws.on_upgrade(|socket| handle_play(socket, state)) + ws.on_upgrade(|socket| { + let (sender, receiver) = socket.split(); + handle_play(sender, receiver, state) + }) } diff --git a/src/play.rs b/src/play.rs deleted file mode 100644 index 0ec9758..0000000 --- a/src/play.rs +++ /dev/null @@ -1,244 +0,0 @@ -use crate::AppState; -use crate::session::{HandObject, PlayerColor}; -use axum::extract::ws::{Message, Utf8Bytes, WebSocket}; -use futures_util::{SinkExt, StreamExt}; -use serde::{Deserialize, Serialize}; -use std::mem; -use std::sync::{Arc, Mutex}; -use tokio::sync::{broadcast, mpsc, oneshot}; - -#[derive(Deserialize)] -enum IncomingPlayMessage { - Initialize { id: String }, - Color(String), -} - -// TODO: Maybe derive Clone, reference interior vals -#[derive(Serialize)] -enum OutgoingPlayMessage { - Initialize { colors: Vec }, - Hand(Vec), - Error, -} - -#[derive(Clone)] -pub enum PlayUpdate { - HandUpdate([Vec; PlayerColor::COUNT]), -} - -// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code) -#[allow(dead_code)] -pub enum Error { - BadJson(serde_json::Error), - Closed, - InvalidSession(String), - InvalidColor, -} - -struct PlayState { - session: Option, - color: PlayerColor, - update_cancel_tx: oneshot::Sender<()>, -} -impl PlayState { - pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self { - Self { - session: Default::default(), - color: Default::default(), - update_cancel_tx, - } - } -} - -impl From for Error { - fn from(value: serde_json::Error) -> Self { - Self::BadJson(value) - } -} - -impl From> for Error { - fn from(_: mpsc::error::SendError) -> Self { - Self::Closed - } -} - -pub async fn handle_play(socket: WebSocket, app_state: Arc) { - let (mut sender, mut receiver) = socket.split(); - let (sender_tx, mut sender_rx) = mpsc::channel(2); - - let (update_cancel_tx, _) = oneshot::channel(); - let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx))); - - let mut send_task = tokio::spawn(async move { - while let Some(message) = sender_rx.recv().await { - let serialized = match serde_json::to_string(&message) { - Ok(serialized) => serialized, - Err(err) => { - eprintln!("Failed to serialize outgoing websocket message: {}", err); - break; - } - }; - if let Err(err) = sender - .send(Message::Text(Utf8Bytes::from(serialized))) - .await - { - eprintln!("Failed to send serialized websocket message: {}", err); - break; - } - } - }); - - let mut recv_task = { - tokio::spawn(async move { - while let Some(msg) = receiver.next().await { - let Ok(Message::Text(text)) = msg else { - continue; - }; - - match serde_json::from_str(text.as_str()) { - Ok(msg) => { - let result = - handle_play_message(msg, sender_tx.clone(), &state, &app_state).await; - match result { - Ok(_) => (), - Err(Error::Closed) => { - eprintln!("Failed to send play message as the channel closed."); - break; - } - Err(_) => { - let result = sender_tx.send(OutgoingPlayMessage::Error).await; - if let Err(err) = result { - eprintln!( - "Failed to send play message as the channel closed: {}", - err - ); - break; - } - } - } - } - Err(_) => { - // TODO: include error details - let result = sender_tx.send(OutgoingPlayMessage::Error).await; - if let Err(err) = result { - eprintln!("Failed to send play message as the channel closed: {}", err); - break; - } - } - } - } - }) - }; - - tokio::select! { - _ = &mut send_task => recv_task.abort(), - _ = &mut recv_task => send_task.abort(), - } -} - -async fn handle_play_message( - message: IncomingPlayMessage, - sender_tx: mpsc::Sender, - state: &Arc>, - app_state: &Arc, -) -> Result<(), Error> { - match message { - IncomingPlayMessage::Initialize { id } => { - let data_opt = app_state.with_session(id.as_str(), |session| { - let colors: Vec = session - .seats - .iter() - .enumerate() - .filter(|(_, hand)| !hand.is_empty()) - .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) - .map(|color| String::from(color.as_ref())) - .collect(); - let update_rx = session.update_tx.subscribe(); - - (colors, update_rx) - }); - // let else used instead of propagating Option::ok_or_else because compiler wouldn't - // know about early return when moving id - let Some((colors, update_rx)) = data_opt else { - return Err(Error::InvalidSession(id)); - }; - - let update_cancel_rx = { - let mut state = state.lock().unwrap(); - let (update_cancel_tx, update_cancel_rx) = oneshot::channel(); - let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(()); - state.session = Some(id); - - update_cancel_rx - }; - { - let sender_tx = sender_tx.clone(); - let state = state.clone(); - - tokio::spawn(async move { - handle_update(update_rx, update_cancel_rx, sender_tx, state).await; - }); - } - - sender_tx - .send(OutgoingPlayMessage::Initialize { colors }) - .await - .map_err(Error::from) - } - IncomingPlayMessage::Color(color) => { - let hand = { - let mut state = state.lock().unwrap(); - state.color = - PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?; - - let name = state - .session - .as_ref() - .ok_or_else(|| Error::InvalidSession(String::default()))?; - - app_state - .with_session(name.as_str(), |session| session.seats[&state.color].clone()) - .ok_or_else(|| Error::InvalidSession(name.clone()))? - }; - - sender_tx - .send(OutgoingPlayMessage::Hand(hand)) - .await - .map_err(Error::from) - } - } -} - -async fn handle_update( - mut update_rx: broadcast::Receiver, - mut cancel_rx: oneshot::Receiver<()>, - sender_tx: mpsc::Sender, - state: Arc>, -) { - loop { - tokio::select! { - update = update_rx.recv() => match update { - Ok(PlayUpdate::HandUpdate(hands)) => { - let colors: Vec = hands.iter().enumerate() - .filter(|(_, hand)| !hand.is_empty()) - .flat_map(|(index, _)| PlayerColor::try_from(index).ok()) - .map(|color| String::from(color.as_ref())) - .collect(); - let _ = sender_tx - .send(OutgoingPlayMessage::Initialize { - colors, - }) - .await; - let hand = { - let color = &state.lock().unwrap().color; - hands[color].to_owned() - }; - let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await; - } - Err(broadcast::error::RecvError::Closed) => break, - Err(broadcast::error::RecvError::Lagged(_)) => continue, - }, - _ = &mut cancel_rx => break, - } - } -} diff --git a/src/session.rs b/src/session.rs index c15bfda..246551e 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,4 +1,4 @@ -use crate::play::PlayUpdate; +use crate::app::socket::PlayUpdate; use serde::{Deserialize, Serialize}; use std::array; use std::ops::Index; -- cgit v1.2.3