MM-56260: connIndex: safely remove conns while iterating (#25785)

PR https://github.com/mattermost/mattermost/pull/22560
introduced a regression in the case where we had
multiple connections for a single user.

Because if connIndex.Remove was called during the iteration
from connIndex.ForUser, then the slice would be modified
during iteration and if a connection got removed,
then a good connection would move from the last index
to the current index. But since we would be actively
iterating, the last index would be read as nil and we would
never be able to reach the good connection.

We fix this by returning a copy of the original slice
if there are more than one elements in the slice.

https://mattermost.atlassian.net/browse/MM-56260
```release-note
Fix a bug where if there were multiple websocket
connections from a single user, then in case one connection
got removed during a broadcast, there was a possibility
that the other good connection might not get the event.
```
This commit is contained in:
Agniva De Sarker
2024-01-05 01:06:22 +05:30
committed by GitHub
parent 59549653a7
commit ca94577cd5
2 changed files with 70 additions and 8 deletions

View File

@@ -521,7 +521,7 @@ func (h *Hub) Start() {
}
if connID := msg.GetBroadcast().ConnectionId; connID != "" {
if webConn := connIndex.byConnectionId[connID]; webConn != nil {
if webConn := connIndex.ForConnection(connID); webConn != nil {
broadcast(webConn)
continue
}
@@ -637,7 +637,22 @@ func (i *hubConnectionIndex) Has(wc *WebConn) bool {
// ForUser returns all connections for a user ID.
func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
return i.byUserId[id]
// Fast path if there is only one or fewer connection.
if len(i.byUserId[id]) <= 1 {
return i.byUserId[id]
}
// If there are multiple connections per user,
// then we have to return a clone of the slice
// to allow connIndex.Remove to be safely called while
// iterating the slice.
conns := make([]*WebConn, len(i.byUserId[id]))
copy(conns, i.byUserId[id])
return conns
}
// ForConnection returns the connection from its ID.
func (i *hubConnectionIndex) ForConnection(id string) *WebConn {
return i.byConnectionId[id]
}
// All returns the full webConn index.

View File

@@ -308,6 +308,53 @@ func TestHubConnIndex(t *testing.T) {
})
}
func TestHubConnIndexIncorrectRemoval(t *testing.T) {
th := Setup(t)
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
// User2
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc2.SetConnectionID("first")
wc2.SetSession(&model.Session{})
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID("myID")
wc3.SetSession(&model.Session{})
wc4 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc4.SetConnectionID("last")
wc4.SetSession(&model.Session{})
connIndex.Add(wc2)
connIndex.Add(wc3)
connIndex.Add(wc4)
for _, wc := range connIndex.ForUser(wc2.UserId) {
if !connIndex.Has(wc) {
require.Failf(t, "Failed to find connection", "connection: %v", wc)
continue
}
if connIndex.ForConnection("myID") != nil {
connIndex.Remove(wc)
}
}
}
func TestHubConnIndexByConnectionId(t *testing.T) {
th := Setup(t)
defer th.TearDown()
@@ -355,18 +402,18 @@ func TestHubConnIndexByConnectionId(t *testing.T) {
connIndex.Add(wc3)
assert.Len(t, connIndex.byConnectionId, 2)
assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID])
assert.Equal(t, wc3, connIndex.byConnectionId[wc3ID])
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID])
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
t.Run("removing", func(t *testing.T) {
connIndex.Remove(wc3)
assert.Len(t, connIndex.byConnectionId, 1)
assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID])
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc3ID])
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID])
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
}