diff --git a/src/rust/lqosd/src/node_manager/ws.rs b/src/rust/lqosd/src/node_manager/ws.rs index ac8904b1..ba674e0a 100644 --- a/src/rust/lqosd/src/node_manager/ws.rs +++ b/src/rust/lqosd/src/node_manager/ws.rs @@ -11,6 +11,7 @@ //! //! Both types of websocket are authenticated using the auth layer. +use std::collections::HashSet; use std::str::FromStr; use std::sync::Arc; @@ -64,12 +65,13 @@ async fn handle_socket(mut socket: WebSocket, channels: Arc) { log::info!("Websocket connected"); let (tx, mut rx) = tokio::sync::mpsc::channel::(10); + let mut subscribed_channels = HashSet::new(); loop { tokio::select! { inbound = socket.recv() => { // Received a websocket message match inbound { - Some(Ok(msg)) => receive_channel_message(msg, channels.clone(), tx.clone()).await, + Some(Ok(msg)) => receive_channel_message(msg, channels.clone(), tx.clone(), &mut subscribed_channels).await, Some(Err(_)) => break, // The channel has closed None => break, // The channel has closed } @@ -96,12 +98,15 @@ async fn handle_socket(mut socket: WebSocket, channels: Arc) { log::info!("Websocket disconnected"); } -async fn receive_channel_message(msg: Message, channels: Arc, tx: Sender) { +async fn receive_channel_message(msg: Message, channels: Arc, tx: Sender, subscribed_channels: &mut HashSet) { log::debug!("Received message: {:?}", msg); if let Ok(text) = msg.to_text() { if let Ok(sub) = serde_json::from_str::(text) { if let Ok(channel) = PublishedChannels::from_str(&sub.channel) { - channels.subscribe(channel, tx.clone()).await; + if !subscribed_channels.contains(&channel) { + channels.subscribe(channel, tx.clone()).await; + subscribed_channels.insert(channel); + } } } } diff --git a/src/rust/lqosd/src/node_manager/ws/published_channels.rs b/src/rust/lqosd/src/node_manager/ws/published_channels.rs index c7e5a256..e848c16f 100644 --- a/src/rust/lqosd/src/node_manager/ws/published_channels.rs +++ b/src/rust/lqosd/src/node_manager/ws/published_channels.rs @@ -1,6 +1,6 @@ use strum::{Display, EnumIter, EnumString}; -#[derive(PartialEq, Clone, Copy, Debug, EnumIter, Display, EnumString)] +#[derive(PartialEq, Clone, Copy, Debug, EnumIter, Display, EnumString, Hash, Eq)] pub enum PublishedChannels { /// Provides a 1-second tick notification to the client Cadence,