Prevent duplicate channel subscription.

This commit is contained in:
Herbert Wolverson 2024-07-19 08:52:45 -05:00
parent bcd7842755
commit 87257df5ad
2 changed files with 9 additions and 4 deletions

View File

@ -11,6 +11,7 @@
//!
//! Both types of websocket are authenticated using the auth layer.
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
@ -64,12 +65,13 @@ async fn handle_socket(mut socket: WebSocket, channels: Arc<PubSub>) {
log::info!("Websocket connected");
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(10);
let mut subscribed_channels = HashSet::new();
loop {
tokio::select! {
inbound = socket.recv() => {
// Received a websocket message
match inbound {
Some(Ok(msg)) => receive_channel_message(msg, channels.clone(), tx.clone()).await,
Some(Ok(msg)) => receive_channel_message(msg, channels.clone(), tx.clone(), &mut subscribed_channels).await,
Some(Err(_)) => break, // The channel has closed
None => break, // The channel has closed
}
@ -96,12 +98,15 @@ async fn handle_socket(mut socket: WebSocket, channels: Arc<PubSub>) {
log::info!("Websocket disconnected");
}
async fn receive_channel_message(msg: Message, channels: Arc<PubSub>, tx: Sender<String>) {
async fn receive_channel_message(msg: Message, channels: Arc<PubSub>, tx: Sender<String>, subscribed_channels: &mut HashSet<PublishedChannels>) {
log::debug!("Received message: {:?}", msg);
if let Ok(text) = msg.to_text() {
if let Ok(sub) = serde_json::from_str::<Subscribe>(text) {
if let Ok(channel) = PublishedChannels::from_str(&sub.channel) {
if !subscribed_channels.contains(&channel) {
channels.subscribe(channel, tx.clone()).await;
subscribed_channels.insert(channel);
}
}
}
}

View File

@ -1,6 +1,6 @@
use strum::{Display, EnumIter, EnumString};
#[derive(PartialEq, Clone, Copy, Debug, EnumIter, Display, EnumString)]
#[derive(PartialEq, Clone, Copy, Debug, EnumIter, Display, EnumString, Hash, Eq)]
pub enum PublishedChannels {
/// Provides a 1-second tick notification to the client
Cadence,