diff --git a/server/channels/app/platform/web_hub.go b/server/channels/app/platform/web_hub.go index 337e283155..25f85b643b 100644 --- a/server/channels/app/platform/web_hub.go +++ b/server/channels/app/platform/web_hub.go @@ -521,7 +521,7 @@ func (h *Hub) Start() { } if connID := msg.GetBroadcast().ConnectionId; connID != "" { - if webConn := connIndex.byConnectionId[connID]; webConn != nil { + if webConn := connIndex.ForConnection(connID); webConn != nil { broadcast(webConn) continue } @@ -637,7 +637,22 @@ func (i *hubConnectionIndex) Has(wc *WebConn) bool { // ForUser returns all connections for a user ID. func (i *hubConnectionIndex) ForUser(id string) []*WebConn { - return i.byUserId[id] + // Fast path if there is only one or fewer connection. + if len(i.byUserId[id]) <= 1 { + return i.byUserId[id] + } + // If there are multiple connections per user, + // then we have to return a clone of the slice + // to allow connIndex.Remove to be safely called while + // iterating the slice. + conns := make([]*WebConn, len(i.byUserId[id])) + copy(conns, i.byUserId[id]) + return conns +} + +// ForConnection returns the connection from its ID. +func (i *hubConnectionIndex) ForConnection(id string) *WebConn { + return i.byConnectionId[id] } // All returns the full webConn index. diff --git a/server/channels/app/platform/web_hub_test.go b/server/channels/app/platform/web_hub_test.go index 6e7b215edb..c2a9716874 100644 --- a/server/channels/app/platform/web_hub_test.go +++ b/server/channels/app/platform/web_hub_test.go @@ -308,6 +308,53 @@ func TestHubConnIndex(t *testing.T) { }) } +func TestHubConnIndexIncorrectRemoval(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + connIndex := newHubConnectionIndex(1 * time.Second) + + // User2 + wc2 := &WebConn{ + Platform: th.Service, + Suite: th.Suite, + UserId: model.NewId(), + } + wc2.SetConnectionID("first") + wc2.SetSession(&model.Session{}) + + wc3 := &WebConn{ + Platform: th.Service, + Suite: th.Suite, + UserId: wc2.UserId, + } + wc3.SetConnectionID("myID") + wc3.SetSession(&model.Session{}) + + wc4 := &WebConn{ + Platform: th.Service, + Suite: th.Suite, + UserId: wc2.UserId, + } + wc4.SetConnectionID("last") + wc4.SetSession(&model.Session{}) + + connIndex.Add(wc2) + connIndex.Add(wc3) + connIndex.Add(wc4) + + for _, wc := range connIndex.ForUser(wc2.UserId) { + if !connIndex.Has(wc) { + require.Failf(t, "Failed to find connection", "connection: %v", wc) + continue + } + + if connIndex.ForConnection("myID") != nil { + connIndex.Remove(wc) + } + } +} + func TestHubConnIndexByConnectionId(t *testing.T) { th := Setup(t) defer th.TearDown() @@ -355,18 +402,18 @@ func TestHubConnIndexByConnectionId(t *testing.T) { connIndex.Add(wc3) assert.Len(t, connIndex.byConnectionId, 2) - assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID]) - assert.Equal(t, wc3, connIndex.byConnectionId[wc3ID]) - assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID]) + assert.Equal(t, wc1, connIndex.ForConnection(wc1ID)) + assert.Equal(t, wc3, connIndex.ForConnection(wc3ID)) + assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID)) }) t.Run("removing", func(t *testing.T) { connIndex.Remove(wc3) assert.Len(t, connIndex.byConnectionId, 1) - assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID]) - assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc3ID]) - assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID]) + assert.Equal(t, wc1, connIndex.ForConnection(wc1ID)) + assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID)) + assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID)) }) }