[MM-58500] Turn off PostedAck when the connection is no longer registered (#27212)

* [MM-58500] Turn off PostedAck when the connection is no longer registered

* Expose active and just check for active instead
This commit is contained in:
Devin Binnie
2024-06-03 09:44:53 -04:00
committed by GitHub
parent 4163db4e5e
commit 4ec50a7ddd
5 changed files with 42 additions and 22 deletions

View File

@@ -107,7 +107,7 @@ type WebConn struct {
deadQueuePointer int
// active indicates whether there is an open websocket connection attached
// to this webConn or not.
active atomic.Bool
Active atomic.Bool
// reuseCount indicates how many times this connection has been reused.
// This is used to differentiate between a fresh connection and
// a reused connection.
@@ -245,7 +245,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
lastLogTimeFull: time.Now(),
originClient: cfg.OriginClient,
}
wc.active.Store(cfg.Active)
wc.Active.Store(cfg.Active)
wc.SetSession(&cfg.Session)
wc.SetSessionToken(cfg.Session.Token)
@@ -555,7 +555,7 @@ func (wc *WebConn) writePump() {
continue
}
if wc.active.Load() && len(wc.send) >= sendFullWarn && time.Since(wc.lastLogTimeFull) > websocketSuppressWarnThreshold {
if wc.Active.Load() && len(wc.send) >= sendFullWarn && time.Since(wc.lastLogTimeFull) > websocketSuppressWarnThreshold {
logData := []mlog.Field{
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()),
@@ -812,7 +812,7 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
case model.WebsocketEventTyping,
model.WebsocketEventStatusChange,
model.WebsocketEventMultipleChannelsViewed:
if wc.active.Load() && time.Since(wc.lastLogTimeSlow) > websocketSuppressWarnThreshold {
if wc.Active.Load() && time.Since(wc.lastLogTimeSlow) > websocketSuppressWarnThreshold {
mlog.Warn(
"websocket.slow: dropping message",
mlog.String("user_id", wc.UserId),

View File

@@ -389,7 +389,7 @@ func (h *Hub) Start() {
conns := connIndex.ForUser(webSessionMessage.userID)
var isRegistered bool
for _, conn := range conns {
if !conn.active.Load() {
if !conn.Active.Load() {
continue
}
if conn.GetSessionToken() == webSessionMessage.sessionToken {
@@ -419,7 +419,7 @@ func (h *Hub) Start() {
// Mark the current one as active.
// There is no need to check if it was inactive or not,
// we will anyways need to make it active.
webConn.active.Store(true)
webConn.Active.Store(true)
connIndex.Add(webConn)
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
@@ -434,7 +434,7 @@ func (h *Hub) Start() {
case webConn := <-h.unregister:
// If already removed (via queue full), then removing again becomes a noop.
// But if not removed, mark inactive.
webConn.active.Store(false)
webConn.Active.Store(false)
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
@@ -471,7 +471,7 @@ func (h *Hub) Start() {
}
var latestActivity int64
for _, conn := range conns {
if !conn.active.Load() {
if !conn.Active.Load() {
continue
}
if conn.lastUserActivityAt > latestActivity {
@@ -491,7 +491,7 @@ func (h *Hub) Start() {
}
case activity := <-h.activity:
for _, webConn := range connIndex.ForUser(activity.userID) {
if !webConn.active.Load() {
if !webConn.Active.Load() {
continue
}
if webConn.GetSessionToken() == activity.sessionToken {
@@ -506,7 +506,7 @@ func (h *Hub) Start() {
case directMsg.conn.send <- directMsg.msg:
default:
// Don't log the warning if it's an inactive connection.
if directMsg.conn.active.Load() {
if directMsg.conn.Active.Load() {
mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
mlog.String("user_id", directMsg.conn.UserId),
mlog.String("conn_id", directMsg.conn.GetConnectionID()))
@@ -533,7 +533,7 @@ func (h *Hub) Start() {
case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
default:
// Don't log the warning if it's an inactive connection.
if webConn.active.Load() {
if webConn.Active.Load() {
mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
mlog.String("user_id", webConn.UserId),
mlog.String("conn_id", webConn.GetConnectionID()))
@@ -601,7 +601,7 @@ func (h *Hub) Start() {
// are inactive or not.
func areAllInactive(conns []*WebConn) bool {
for _, conn := range conns {
if conn.active.Load() {
if conn.Active.Load() {
return false
}
}
@@ -689,7 +689,7 @@ func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
func (i *hubConnectionIndex) ForUserActiveCount(id string) int {
cnt := 0
for _, conn := range i.ForUser(id) {
if conn.active.Load() {
if conn.Active.Load() {
cnt++
}
}
@@ -714,7 +714,7 @@ func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID s
return nil
}
for _, conn := range i.ForUser(userID) {
if conn.GetConnectionID() == connectionID && !conn.active.Load() {
if conn.GetConnectionID() == connectionID && !conn.Active.Load() {
i.Remove(conn)
return conn
}
@@ -727,7 +727,7 @@ func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID s
func (i *hubConnectionIndex) RemoveInactiveConnections() {
now := model.GetMillis()
for conn := range i.byConnection {
if !conn.active.Load() && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() {
if !conn.Active.Load() && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() {
i.Remove(conn)
}
}
@@ -739,7 +739,7 @@ func (i *hubConnectionIndex) RemoveInactiveConnections() {
func (i *hubConnectionIndex) AllActive() int {
cnt := 0
for conn := range i.byConnection {
if conn.active.Load() {
if conn.Active.Load() {
cnt++
}
}

View File

@@ -429,7 +429,7 @@ func TestHubConnIndexInactive(t *testing.T) {
Platform: th.Service,
UserId: model.NewId(),
}
wc1.active.Store(true)
wc1.Active.Store(true)
wc1.SetConnectionID("conn1")
wc1.SetSession(&model.Session{})
@@ -438,7 +438,7 @@ func TestHubConnIndexInactive(t *testing.T) {
Platform: th.Service,
UserId: model.NewId(),
}
wc2.active.Store(true)
wc2.Active.Store(true)
wc2.SetConnectionID("conn2")
wc2.SetSession(&model.Session{})
@@ -446,7 +446,7 @@ func TestHubConnIndexInactive(t *testing.T) {
Platform: th.Service,
UserId: wc2.UserId,
}
wc3.active.Store(false)
wc3.Active.Store(false)
wc3.SetConnectionID("conn3")
wc3.SetSession(&model.Session{})

View File

@@ -83,7 +83,7 @@ func usePostedAckHook(message *model.WebSocketEvent, postedUserId string, channe
func (h *postedAckBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
// Don't ACK unless we say to explicitly
if !webConn.PostedAck {
if !(webConn.PostedAck && webConn.Active.Load()) {
return nil
}

View File

@@ -91,6 +91,7 @@ func TestPostedAckHook_Process(t *testing.T) {
Platform: &platform.PlatformService{},
PostedAck: true,
}
webConn.Active.Store(true)
webConn.SetSession(&model.Session{})
t.Run("should ack if user is in the list of users to notify", func(t *testing.T) {
@@ -142,14 +143,33 @@ func TestPostedAckHook_Process(t *testing.T) {
})
t.Run("should not ack if posted ack is false", func(t *testing.T) {
mobileWebConn := &platform.WebConn{
noAckWebConn := &platform.WebConn{
UserId: userID,
Platform: &platform.PlatformService{},
PostedAck: false,
}
noAckWebConn.Active.Store(true)
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
hook.Process(msg, mobileWebConn, map[string]any{
hook.Process(msg, noAckWebConn, map[string]any{
"posted_user_id": model.NewId(),
"channel_type": model.ChannelTypeDirect,
"users": []string{},
})
assert.Nil(t, msg.Event().GetData()["should_ack"])
})
t.Run("should not ack if connection is not active", func(t *testing.T) {
inactiveWebConn := &platform.WebConn{
UserId: userID,
Platform: &platform.PlatformService{},
PostedAck: false,
}
inactiveWebConn.Active.Store(true)
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
hook.Process(msg, inactiveWebConn, map[string]any{
"posted_user_id": model.NewId(),
"channel_type": model.ChannelTypeDirect,
"users": []string{},