diff options
| author | Jomar Milan <jomarm@jomarm.com> | 2026-06-07 23:40:15 -0700 |
|---|---|---|
| committer | Jomar Milan <jomarm@jomarm.com> | 2026-06-07 23:40:15 -0700 |
| commit | a01f114a641121c77ab80fe43b0f3770458f3afc (patch) | |
| tree | 07f10268e848cfd71c048eb02cd2be4a920ffca5 | |
| parent | a1746f7c30519abc2f3a293a8f02cc0b2015804d (diff) | |
Use RwLock for sessions in app state
| -rw-r--r-- | src/main.rs | 16 | ||||
| -rw-r--r-- | src/play.rs | 30 |
2 files changed, 27 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs index b157b71..a09dd84 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ use axum::{Json, Router}; use rust_embed::Embed; use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::{SystemTime, UNIX_EPOCH}; #[derive(Embed)] @@ -22,13 +22,13 @@ use std::time::{SystemTime, UNIX_EPOCH}; struct EmbedAsset; struct AppState { - sessions: Mutex<HashMap<String, Session>>, + sessions: RwLock<HashMap<String, Session>>, } impl AppState { fn new() -> Self { AppState { - sessions: Mutex::new(HashMap::new()), + sessions: RwLock::new(HashMap::new()), } } } @@ -59,7 +59,7 @@ fn serve_template(template: impl Template) -> Response { } async fn serve_index(State(state): State<Arc<AppState>>) -> Response { - let sessions = state.sessions.lock().unwrap(); + let sessions = state.sessions.read().unwrap(); serve_template(IndexTemplate { sessions: &sessions, }) @@ -85,7 +85,7 @@ async fn visit_session( ) -> Response { let passcode = query.get("passcode"); - let sessions = state.sessions.lock().unwrap(); + let sessions = state.sessions.read().unwrap(); match sessions.get(&id) { Some(session) => match passcode { @@ -110,7 +110,7 @@ async fn create_session( .unwrap_or(675603000) .to_string(); - let mut sessions = state.sessions.lock().unwrap(); + let mut sessions = state.sessions.write().unwrap(); let session = Session::new(name, passcode.clone()); sessions.insert(id, session); @@ -119,7 +119,7 @@ async fn create_session( } async fn serve_hands(Path(id): Path<String>, State(state): State<Arc<AppState>>) -> Response { - let sessions = state.sessions.lock().unwrap(); + let sessions = state.sessions.read().unwrap(); match sessions.get(&id) { Some(session) => Json(session.hands.keys().collect::<Vec<_>>()).into_response(), @@ -132,7 +132,7 @@ async fn update_hands( State(state): State<Arc<AppState>>, Json(payload): Json<HashMap<String, Vec<HandObject>>>, ) -> Response { - let mut sessions = state.sessions.lock().unwrap(); + let mut sessions = state.sessions.write().unwrap(); match sessions.get_mut(&id) { Some(session) => { diff --git a/src/play.rs b/src/play.rs index af5314a..574e8b8 100644 --- a/src/play.rs +++ b/src/play.rs @@ -1,9 +1,9 @@ use crate::AppState; +use crate::session::HandObject; use axum::extract::ws::{Message, Utf8Bytes, WebSocket}; use serde::{Deserialize, Serialize}; use std::error::Error; use std::sync::Arc; -use crate::session::HandObject; #[derive(Deserialize)] enum IncomingPlayMessage { @@ -25,17 +25,19 @@ pub async fn handle_play(mut socket: WebSocket, app_state: Arc<AppState>) { let mut play_state = PlayState { id: None }; while let Some(msg) = socket.recv().await { - let mut process = async |msg: Result<Message, axum::Error>, play_state: &mut PlayState| -> Result<(), Box<dyn Error>> { + let mut process = async |msg: Result<Message, axum::Error>, + play_state: &mut PlayState| + -> Result<(), Box<dyn Error>> { let msg: IncomingPlayMessage = serde_json::from_str(msg?.to_text()?)?; match msg { IncomingPlayMessage::Initialize { id } => { - // Blocked so that the mutex guard is dropped after cloning the color names, + // Blocked so that the guard is dropped after cloning the color names, // preventing a potential deadlock when using .await after sending something // through the socket, which would be possible if using tokio::sync::Mutex. - // Of course, the sessions Mutex is std::sync::Mutex instead, which does not - // allow locking the mutex through .await anyway. + // Of course, the sessions HashMap is wrapped instd::sync types instead, which + // does not enable locking the mutex through .await anyway. let colors: Vec<String> = { - let sessions = app_state.sessions.lock().unwrap(); + let sessions = app_state.sessions.read().unwrap(); // TODO: Non-string Error might be useful let session = sessions.get(&id).ok_or("Session did not exist")?; @@ -52,13 +54,19 @@ pub async fn handle_play(mut socket: WebSocket, app_state: Arc<AppState>) { &response, )?))) .await?; - }, + } IncomingPlayMessage::Color(color) => { let hand: Vec<HandObject> = { - let sessions = app_state.sessions.lock().unwrap(); - let session = sessions.get(play_state.id.as_ref().ok_or("No session was joined")?).ok_or("Session did not exist")?; + let sessions = app_state.sessions.read().unwrap(); + let session = sessions + .get(play_state.id.as_ref().ok_or("No session was joined")?) + .ok_or("Session did not exist")?; - (*session.hands.get(&color).ok_or("No player seated by that color")?).clone() + (*session + .hands + .get(&color) + .ok_or("No player seated by that color")?) + .clone() }; let response = OutgoingPlayMessage::Hand(hand.iter().collect()); @@ -67,7 +75,7 @@ pub async fn handle_play(mut socket: WebSocket, app_state: Arc<AppState>) { &response, )?))) .await?; - }, + } } Ok(()) }; |
