diff options
| author | Jomar Milan <jomarm@jomarm.com> | 2026-06-09 23:31:48 -0700 |
|---|---|---|
| committer | Jomar Milan <jomarm@jomarm.com> | 2026-06-09 23:31:48 -0700 |
| commit | 13374b7928788e8cdc6c7905209bafdf943dc02e (patch) | |
| tree | 217eb77e6b1e0cfcf07fb6d113db0d733c3ad0e9 /src | |
| parent | acf5e40d02a25a6e99ef23ef61aca8cd261de9d3 (diff) | |
Maintain weak references to sessions for play sockets
Changes in this commit have somewhat mollified my code-smell-o-meter.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main.rs | 36 | ||||
| -rw-r--r-- | src/play.rs | 148 | ||||
| -rw-r--r-- | src/session.rs | 11 |
3 files changed, 117 insertions, 78 deletions
diff --git a/src/main.rs b/src/main.rs index fc0a7bb..9e7f21b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ mod session; mod template; use crate::play::handle_play; -use crate::session::{HandObject, Session}; +use crate::session::{HandObject, Seat, Session}; use crate::template::{IndexTemplate, SessionTemplate}; use askama::Template; use axum::extract::{Path, Query, State, WebSocketUpgrade}; @@ -14,14 +14,14 @@ use axum::{Json, Router}; use rust_embed::Embed; use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; #[derive(Embed)] #[folder = "assets/"] struct EmbedAsset; struct AppState { - sessions: RwLock<HashMap<String, Session>>, + sessions: RwLock<HashMap<String, Arc<Mutex<Session>>>>, } impl AppState { @@ -48,7 +48,7 @@ async fn main() { axum::serve(listener, app).await.unwrap(); } -fn serve_template(template: impl Template) -> Result<Html<String>, &'static str> { +fn serve_template(template: &impl Template) -> Result<Html<String>, &'static str> { template.render().map(Html).map_err(|err| { eprintln!("Template render error: {}", err); "Template render error" @@ -57,7 +57,7 @@ fn serve_template(template: impl Template) -> Result<Html<String>, &'static str> async fn serve_index() -> axum::response::Result<Html<String>> { let template = IndexTemplate; - Ok(serve_template(template)?) + Ok(serve_template(&template)?) } async fn serve_static(Path(path): Path<String>) -> Response { @@ -88,10 +88,15 @@ async fn visit_session( let sessions = state.sessions.read().unwrap(); let session = sessions .get(&id) - .ok_or((StatusCode::NOT_FOUND, "Session does not exist"))?; - - let template = SessionTemplate { id: &id, session }; - Ok(serve_template(template)?) + .ok_or((StatusCode::NOT_FOUND, "Session does not exist"))? + .lock() + .unwrap(); + + let template = SessionTemplate { + id: &id, + session: &session, + }; + Ok(serve_template(&template)?) } async fn create_session( @@ -104,7 +109,7 @@ async fn create_session( let mut sessions = state.sessions.write().unwrap(); let session = Session::new(name); - sessions.insert(id, session); + sessions.insert(id, Arc::new(Mutex::new(session))); StatusCode::CREATED } @@ -118,7 +123,16 @@ async fn update_hands( match sessions.get_mut(&id) { Some(session) => { - session.hands = payload; + let mut session = session.lock().unwrap(); + + for (color, hand) in payload { + let seat = session + .seats + .entry(color) + .or_insert_with(|| Seat { hand: Vec::new() }); + + seat.hand = hand; + } StatusCode::NO_CONTENT } None => StatusCode::NOT_FOUND, diff --git a/src/play.rs b/src/play.rs index e08a895..f401fea 100644 --- a/src/play.rs +++ b/src/play.rs @@ -5,6 +5,16 @@ use serde::{Deserialize, Serialize}; use std::error::Error; use std::sync::Arc; +macro_rules! send_message_or_break { + ($socket:expr, $message:expr) => {{ + let result = send_outgoing_message($socket, $message).await; + if let Err(err) = result { + eprintln!("Failed to send message to socket: {}", err); + break; + } + }}; +} + #[derive(Deserialize)] enum IncomingPlayMessage { Initialize { id: String }, @@ -15,79 +25,93 @@ enum IncomingPlayMessage { enum OutgoingPlayMessage<'a> { Initialize { colors: Vec<&'a String> }, Hand(Vec<&'a HandObject>), + // TODO: include error details + Error, } -struct PlayState { - id: Option<String>, +async fn send_outgoing_message( + socket: &mut WebSocket, + message: &OutgoingPlayMessage<'_>, +) -> Result<(), Box<dyn Error>> { + let serialized = serde_json::to_string(message)?; + socket + .send(Message::Text(Utf8Bytes::from(serialized))) + .await + .map_err(Box::from) } pub async fn handle_play(mut socket: WebSocket, app_state: Arc<AppState>) { - let mut play_state = PlayState { id: None }; + let mut player_session = None; while let Some(msg) = socket.recv().await { - let mut process = async |msg: Result<Message, axum::Error>, - play_state: &mut PlayState| - -> Result<(), Box<dyn Error>> { - let msg: IncomingPlayMessage = serde_json::from_str(msg?.to_text()?)?; - match msg { - IncomingPlayMessage::Initialize { id } => { - // Blocked so that the guard is dropped after cloning the color names, - // preventing a potential deadlock when using .await after sending something - // through the socket, which would be possible if using tokio::sync::Mutex. - // Of course, the sessions HashMap is wrapped in std::sync types instead, which - // does not enable locking the mutex through .await anyway. - let colors: Vec<String> = { - let sessions = app_state.sessions.read().unwrap(); - // TODO: Non-string Error might be useful - let session = sessions.get(&id).ok_or("Session did not exist")?; - - session.hands.keys().cloned().collect() - }; - - play_state.id = Some(id); - - let response = OutgoingPlayMessage::Initialize { - colors: colors.iter().collect(), - }; - socket - .send(Message::Text(Utf8Bytes::from(serde_json::to_string( - &response, - )?))) - .await?; - } - IncomingPlayMessage::Color(color) => { - let hand: Vec<HandObject> = { - let sessions = app_state.sessions.read().unwrap(); - let session = sessions - .get(play_state.id.as_ref().ok_or("No session was joined")?) - .ok_or("Session did not exist")?; + let Ok(Message::Text(text)) = msg else { + continue; + }; - (*session - .hands - .get(&color) - .ok_or("No player seated by that color")?) - .clone() - }; + match serde_json::from_str(text.as_str()) { + Ok(IncomingPlayMessage::Initialize { id }) => { + let session = { + let sessions = app_state.sessions.read().unwrap(); + sessions + .get(&id) + .map(Arc::to_owned) + .ok_or("Session did not exist") + }; - let response = OutgoingPlayMessage::Hand(hand.iter().collect()); - socket - .send(Message::Text(Utf8Bytes::from(serde_json::to_string( - &response, - )?))) - .await?; + match session { + Ok(session) => { + let colors: Vec<String> = session + .lock() + .unwrap() + .seats + .keys() + .map(String::to_owned) + .collect(); + player_session = Some(Arc::downgrade(&session)); + let response = OutgoingPlayMessage::Initialize { + colors: colors.iter().collect(), + }; + send_message_or_break!(&mut socket, &response); + } + Err(err) => { + eprintln!("Failed to access session: {}", err); + let response = OutgoingPlayMessage::Error; + send_message_or_break!(&mut socket, &response); + } } } - Ok(()) - }; - - if let Ok(Message::Text(_)) = msg - && let Err(err) = process(msg, &mut play_state).await - { - eprintln!( - "Encountered an error while handling a message from a player: {}", - err - ); - break; + Ok(IncomingPlayMessage::Color(color)) => { + let Some(session) = player_session.clone().and_then(|session| session.upgrade()) + else { + let response = OutgoingPlayMessage::Error; + send_message_or_break!(&mut socket, &response); + break; + }; + let hand = session + .lock() + .unwrap() + .seats + .get(&color) + .map(|seat| (&seat.hand).to_owned()); + match hand { + Some(hand) => { + // Response constructed here because the inner value of the Option would be dropped outside the match block + let response = OutgoingPlayMessage::Hand(hand.iter().collect()); + send_message_or_break!(&mut socket, &response); + } + None => { + let response = OutgoingPlayMessage::Error; + send_message_or_break!(&mut socket, &response); + } + }; + } + Err(err) => { + eprintln!( + "Encountered an error while handling a message from a player: {}", + err + ); + break; + } } } } diff --git a/src/session.rs b/src/session.rs index e472adf..acd7615 100644 --- a/src/session.rs +++ b/src/session.rs @@ -3,13 +3,10 @@ use std::collections::HashMap; pub struct Session { pub steam_name: String, - pub hands: HashMap<String, Vec<HandObject>>, + pub seats: HashMap<String, Seat>, } -// TODO: The values on these variants will be used in the future and there will be more variants. -// Once this happens, the dead_code lint should no longer be suppressed. #[derive(Clone, Serialize, Deserialize)] -#[allow(dead_code)] pub enum HandObject { CustomDeck(CustomDeck), } @@ -40,11 +37,15 @@ pub struct CustomDeck { card_id: f64, } +pub struct Seat { + pub hand: Vec<HandObject>, +} + impl Session { pub fn new(steam_name: String) -> Self { Session { steam_name, - hands: HashMap::new(), + seats: HashMap::new(), } } } |
