mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-56260: connIndex: safely remove conns while iterating (#25785)
PR https://github.com/mattermost/mattermost/pull/22560 introduced a regression in the case where we had multiple connections for a single user. Because if connIndex.Remove was called during the iteration from connIndex.ForUser, then the slice would be modified during iteration and if a connection got removed, then a good connection would move from the last index to the current index. But since we would be actively iterating, the last index would be read as nil and we would never be able to reach the good connection. We fix this by returning a copy of the original slice if there are more than one elements in the slice. https://mattermost.atlassian.net/browse/MM-56260 ```release-note Fix a bug where if there were multiple websocket connections from a single user, then in case one connection got removed during a broadcast, there was a possibility that the other good connection might not get the event. ```
This commit is contained in:
@@ -521,7 +521,7 @@ func (h *Hub) Start() {
|
||||
}
|
||||
|
||||
if connID := msg.GetBroadcast().ConnectionId; connID != "" {
|
||||
if webConn := connIndex.byConnectionId[connID]; webConn != nil {
|
||||
if webConn := connIndex.ForConnection(connID); webConn != nil {
|
||||
broadcast(webConn)
|
||||
continue
|
||||
}
|
||||
@@ -637,7 +637,22 @@ func (i *hubConnectionIndex) Has(wc *WebConn) bool {
|
||||
|
||||
// ForUser returns all connections for a user ID.
|
||||
func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
|
||||
return i.byUserId[id]
|
||||
// Fast path if there is only one or fewer connection.
|
||||
if len(i.byUserId[id]) <= 1 {
|
||||
return i.byUserId[id]
|
||||
}
|
||||
// If there are multiple connections per user,
|
||||
// then we have to return a clone of the slice
|
||||
// to allow connIndex.Remove to be safely called while
|
||||
// iterating the slice.
|
||||
conns := make([]*WebConn, len(i.byUserId[id]))
|
||||
copy(conns, i.byUserId[id])
|
||||
return conns
|
||||
}
|
||||
|
||||
// ForConnection returns the connection from its ID.
|
||||
func (i *hubConnectionIndex) ForConnection(id string) *WebConn {
|
||||
return i.byConnectionId[id]
|
||||
}
|
||||
|
||||
// All returns the full webConn index.
|
||||
|
||||
@@ -308,6 +308,53 @@ func TestHubConnIndex(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHubConnIndexIncorrectRemoval(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
|
||||
// User2
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc2.SetConnectionID("first")
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID("myID")
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
wc4 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc4.SetConnectionID("last")
|
||||
wc4.SetSession(&model.Session{})
|
||||
|
||||
connIndex.Add(wc2)
|
||||
connIndex.Add(wc3)
|
||||
connIndex.Add(wc4)
|
||||
|
||||
for _, wc := range connIndex.ForUser(wc2.UserId) {
|
||||
if !connIndex.Has(wc) {
|
||||
require.Failf(t, "Failed to find connection", "connection: %v", wc)
|
||||
continue
|
||||
}
|
||||
|
||||
if connIndex.ForConnection("myID") != nil {
|
||||
connIndex.Remove(wc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubConnIndexByConnectionId(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
@@ -355,18 +402,18 @@ func TestHubConnIndexByConnectionId(t *testing.T) {
|
||||
connIndex.Add(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 2)
|
||||
assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID])
|
||||
assert.Equal(t, wc3, connIndex.byConnectionId[wc3ID])
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID])
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
|
||||
t.Run("removing", func(t *testing.T) {
|
||||
connIndex.Remove(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 1)
|
||||
assert.Equal(t, wc1, connIndex.byConnectionId[wc1ID])
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc3ID])
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.byConnectionId[wc2ID])
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user