summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main.rs20
-rw-r--r--src/play.rs276
-rw-r--r--src/session.rs6
3 files changed, 178 insertions, 124 deletions
diff --git a/src/main.rs b/src/main.rs
index c05d3c6..b423f6b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,16 +4,12 @@
//! access to players to access and manage the contents of their hands from without using the game
//! screen by using a web browser.
-#![warn(
- missing_docs,
- missing_debug_implementations
-)]
+#![warn(missing_docs, missing_debug_implementations)]
mod play;
mod session;
mod template;
-use std::array;
use crate::play::handle_play;
use crate::session::{HandObject, PlayerColor, Session};
use crate::template::{IndexTemplate, SessionTemplate};
@@ -24,6 +20,7 @@ use axum::response::{ErrorResponse, Html, IntoResponse, Redirect, Response};
use axum::routing::{any, get, put};
use axum::{Json, Router};
use rust_embed::Embed;
+use std::array;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
@@ -32,15 +29,14 @@ use std::sync::{Arc, Mutex, RwLock};
#[folder = "assets/"]
struct EmbedAsset;
+#[derive(Default)]
struct AppState {
- sessions: RwLock<HashMap<String, Arc<Mutex<Session>>>>,
+ sessions: RwLock<HashMap<String, Mutex<Session>>>,
}
impl AppState {
fn new() -> Self {
- AppState {
- sessions: RwLock::new(HashMap::new()),
- }
+ Self::default()
}
}
@@ -120,7 +116,7 @@ async fn create_session(
let mut sessions = state.sessions.write().unwrap();
let session = Session::new(name);
- sessions.insert(id, Arc::new(Mutex::new(session)));
+ sessions.insert(id, Mutex::new(session));
StatusCode::CREATED
}
@@ -130,7 +126,7 @@ async fn update_hands(
State(state): State<Arc<AppState>>,
Json(payload): Json<HashMap<String, Vec<HandObject>>>,
) -> StatusCode {
- let mut sessions = state.sessions.write().unwrap();
+ let sessions = state.sessions.read().unwrap();
let hand = array::from_fn(|i| {
let color = PlayerColor::try_from(i);
@@ -144,7 +140,7 @@ async fn update_hands(
}
});
- match sessions.get_mut(&id) {
+ match sessions.get(&id) {
Some(session) => {
let mut session = session.lock().unwrap();
diff --git a/src/play.rs b/src/play.rs
index 2f364c0..998ac29 100644
--- a/src/play.rs
+++ b/src/play.rs
@@ -1,13 +1,11 @@
use crate::AppState;
-use crate::session::{HandObject, PlayerColor, Session};
+use crate::session::{HandObject, PlayerColor};
use axum::extract::ws::{Message, Utf8Bytes, WebSocket};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
-use std::sync::{Arc, Mutex, RwLock, Weak};
-use tokio::sync::broadcast::Receiver;
-use tokio::sync::broadcast::error::RecvError;
-use tokio::sync::mpsc;
-use tokio::sync::mpsc::Sender;
+use std::mem;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{broadcast, mpsc, oneshot};
#[derive(Deserialize)]
enum IncomingPlayMessage {
@@ -20,7 +18,6 @@ enum IncomingPlayMessage {
enum OutgoingPlayMessage {
Initialize { colors: Vec<String> },
Hand(Vec<HandObject>),
- // TODO: include error details
Error,
}
@@ -29,10 +26,49 @@ pub enum PlayUpdate {
HandUpdate([Vec<HandObject>; PlayerColor::COUNT]),
}
+// TODO: Use in OutgoingPlayMessage::Error, then remove allow(dead_code)
+#[allow(dead_code)]
+pub enum Error {
+ BadJson(serde_json::Error),
+ Closed,
+ InvalidSession(String),
+ InvalidColor,
+}
+
+struct PlayState {
+ session: Option<String>,
+ color: PlayerColor,
+ update_cancel_tx: oneshot::Sender<()>,
+}
+impl PlayState {
+ pub fn new(update_cancel_tx: oneshot::Sender<()>) -> Self {
+ Self {
+ session: Default::default(),
+ color: Default::default(),
+ update_cancel_tx,
+ }
+ }
+}
+
+impl From<serde_json::Error> for Error {
+ fn from(value: serde_json::Error) -> Self {
+ Self::BadJson(value)
+ }
+}
+
+impl From<mpsc::error::SendError<OutgoingPlayMessage>> for Error {
+ fn from(_: mpsc::error::SendError<OutgoingPlayMessage>) -> Self {
+ Self::Closed
+ }
+}
+
pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
let (sender_tx, mut sender_rx) = mpsc::channel(2);
+ let (update_cancel_tx, _) = oneshot::channel();
+ let state = Arc::new(Mutex::new(PlayState::new(update_cancel_tx)));
+
let mut send_task = tokio::spawn(async move {
while let Some(message) = sender_rx.recv().await {
let serialized = match serde_json::to_string(&message) {
@@ -53,101 +89,42 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) {
});
let mut recv_task = {
- let sender_tx = sender_tx.clone();
tokio::spawn(async move {
- let mut player_session = None;
- let player_color = Arc::new(RwLock::new(PlayerColor::Grey));
-
while let Some(msg) = receiver.next().await {
let Ok(Message::Text(text)) = msg else {
continue;
};
match serde_json::from_str(text.as_str()) {
- Ok(IncomingPlayMessage::Initialize { id }) => {
- let session = {
- let sessions = app_state.sessions.read().unwrap();
- sessions
- .get(&id)
- .map(Arc::clone)
- .ok_or("Session did not exist")
- };
-
- match session {
- Ok(session) => {
- let (colors, update_rx) = {
- let session = session.lock().unwrap();
-
- let colors: Vec<String> = session.seats.iter().enumerate()
- .filter(|(_, hand)| hand.len() > 0)
- .flat_map(|(index, _)| PlayerColor::try_from(index).ok())
- .map(|color| String::from(color.as_ref()))
- .collect();
- let update_rx = session.update_tx.subscribe();
-
- (colors, update_rx)
- };
-
- player_session = Some(Arc::downgrade(&session));
- {
- let sender_tx = sender_tx.clone();
- let player_session = Arc::downgrade(&session);
- let player_color = player_color.clone();
-
- tokio::spawn(async move {
- handle_update(
- update_rx,
- sender_tx,
- player_session,
- player_color,
- )
- .await
- });
- }
-
- let response = OutgoingPlayMessage::Initialize { colors };
- if sender_tx.send(response).await.is_err() {
- break;
- }
+ Ok(msg) => {
+ let result =
+ handle_play_message(msg, sender_tx.clone(), &state, &app_state).await;
+ match result {
+ Ok(_) => (),
+ Err(Error::Closed) => {
+ eprintln!("Failed to send play message as the channel closed.");
+ break;
}
- Err(err) => {
- eprintln!("Failed to access session: {}", err);
- let response = OutgoingPlayMessage::Error;
- if sender_tx.send(response).await.is_err() {
+ Err(_) => {
+ let result = sender_tx.send(OutgoingPlayMessage::Error).await;
+ if let Err(err) = result {
+ eprintln!(
+ "Failed to send play message as the channel closed: {}",
+ err
+ );
break;
}
}
}
}
- Ok(IncomingPlayMessage::Color(color)) => {
- let Some(session) =
- player_session.clone().and_then(|session| session.upgrade())
- else {
- let _ = sender_tx.send(OutgoingPlayMessage::Error).await;
- break
- };
- let Ok(color) = PlayerColor::try_from(color.as_str()) else {
- let _ = sender_tx.send(OutgoingPlayMessage::Error).await;
- break
- };
-
- let hand = session.lock().unwrap().seats[&color].clone();
- *player_color.write().unwrap() = color;
- if sender_tx
- .send(OutgoingPlayMessage::Hand(hand))
- .await
- .is_err()
- {
+ Err(_) => {
+ // TODO: include error details
+ let result = sender_tx.send(OutgoingPlayMessage::Error).await;
+ if let Err(err) = result {
+ eprintln!("Failed to send play message as the channel closed: {}", err);
break;
}
}
- Err(err) => {
- eprintln!(
- "Encountered an error while handling a message from a player: {}",
- err
- );
- break;
- }
}
}
})
@@ -159,33 +136,114 @@ pub async fn handle_play(socket: WebSocket, app_state: Arc<AppState>) {
}
}
-async fn handle_update(
- mut update_rx: Receiver<PlayUpdate>,
- sender_tx: Sender<OutgoingPlayMessage>,
- _player_session: Weak<Mutex<Session>>,
- player_color: Arc<RwLock<PlayerColor>>,
-) {
- loop {
- match update_rx.recv().await {
- Ok(PlayUpdate::HandUpdate(hands)) => {
- let colors: Vec<String> = hands.iter().enumerate()
- .filter(|(_, hand)| hand.len() > 0)
+async fn handle_play_message(
+ message: IncomingPlayMessage,
+ sender_tx: mpsc::Sender<OutgoingPlayMessage>,
+ state: &Arc<Mutex<PlayState>>,
+ app_state: &Arc<AppState>,
+) -> Result<(), Error> {
+ match message {
+ IncomingPlayMessage::Initialize { id } => {
+ let (colors, update_rx) = {
+ let sessions = app_state.sessions.read().unwrap();
+ let session = sessions.get(&id).map(|session| session.lock().unwrap());
+ // The Option is unwrapped with let else here instead of ok_or and propagating
+ // because the error moves id which would be used later, and the compiler does not
+ // reason about propagation returning early.
+ let Some(session) = session else {
+ return Err(Error::InvalidSession(id));
+ };
+
+ let colors: Vec<String> = session
+ .seats
+ .iter()
+ .enumerate()
+ .filter(|(_, hand)| !hand.is_empty())
.flat_map(|(index, _)| PlayerColor::try_from(index).ok())
.map(|color| String::from(color.as_ref()))
.collect();
- let _ = sender_tx
- .send(OutgoingPlayMessage::Initialize {
- colors,
- })
- .await;
- let hand = {
- let color = player_color.read().unwrap();
- hands[usize::from(&*color)].to_owned()
- };
- let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await;
+ let update_rx = session.update_tx.subscribe();
+
+ (colors, update_rx)
+ };
+
+ let update_cancel_rx = {
+ let mut state = state.lock().unwrap();
+ let (update_cancel_tx, update_cancel_rx) = oneshot::channel();
+ let _ = mem::replace(&mut state.update_cancel_tx, update_cancel_tx).send(());
+ state.session = Some(id);
+
+ update_cancel_rx
+ };
+ {
+ let sender_tx = sender_tx.clone();
+ let state = state.clone();
+
+ tokio::spawn(async move {
+ handle_update(update_rx, update_cancel_rx, sender_tx, state).await;
+ });
}
- Err(RecvError::Closed) => break,
- Err(RecvError::Lagged(_)) => continue,
+
+ sender_tx
+ .send(OutgoingPlayMessage::Initialize { colors })
+ .await
+ .map_err(Error::from)
+ }
+ IncomingPlayMessage::Color(color) => {
+ let hand = {
+ let mut state = state.lock().unwrap();
+ state.color =
+ PlayerColor::try_from(color.as_str()).map_err(|_| Error::InvalidColor)?;
+ let name = state
+ .session
+ .as_ref()
+ .ok_or_else(|| Error::InvalidSession(String::default()))?;
+ let sessions = app_state.sessions.read().unwrap();
+ let session = sessions
+ .get(name)
+ .ok_or(Error::InvalidSession(name.clone()))?;
+
+ session.lock().unwrap().seats[&state.color].clone()
+ };
+
+ sender_tx
+ .send(OutgoingPlayMessage::Hand(hand))
+ .await
+ .map_err(Error::from)
+ }
+ }
+}
+
+async fn handle_update(
+ mut update_rx: broadcast::Receiver<PlayUpdate>,
+ mut cancel_rx: oneshot::Receiver<()>,
+ sender_tx: mpsc::Sender<OutgoingPlayMessage>,
+ state: Arc<Mutex<PlayState>>,
+) {
+ loop {
+ tokio::select! {
+ update = update_rx.recv() => match update {
+ Ok(PlayUpdate::HandUpdate(hands)) => {
+ let colors: Vec<String> = hands.iter().enumerate()
+ .filter(|(_, hand)| !hand.is_empty())
+ .flat_map(|(index, _)| PlayerColor::try_from(index).ok())
+ .map(|color| String::from(color.as_ref()))
+ .collect();
+ let _ = sender_tx
+ .send(OutgoingPlayMessage::Initialize {
+ colors,
+ })
+ .await;
+ let hand = {
+ let color = &state.lock().unwrap().color;
+ hands[color].to_owned()
+ };
+ let _ = sender_tx.send(OutgoingPlayMessage::Hand(hand)).await;
+ }
+ Err(broadcast::error::RecvError::Closed) => break,
+ Err(broadcast::error::RecvError::Lagged(_)) => continue,
+ },
+ _ = &mut cancel_rx => break,
}
}
}
diff --git a/src/session.rs b/src/session.rs
index 096ae1a..f4c8fe3 100644
--- a/src/session.rs
+++ b/src/session.rs
@@ -1,6 +1,6 @@
-use std::array;
use crate::play::PlayUpdate;
use serde::{Deserialize, Serialize};
+use std::array;
use std::ops::Index;
use tokio::sync::broadcast;
@@ -153,7 +153,7 @@ impl TryFrom<&str> for PlayerColor {
"Pink" => Ok(Self::Pink),
"Grey" => Ok(Self::Grey),
"Black" => Ok(Self::Black),
- _ => Err(())
+ _ => Err(()),
}
}
}
@@ -175,7 +175,7 @@ impl TryFrom<usize> for PlayerColor {
9 => Ok(Self::Pink),
10 => Ok(Self::Grey),
11 => Ok(Self::Black),
- _ => Err(())
+ _ => Err(()),
}
}
}