MM-61130: Use a channelMember map at web_hub level (#28810)

Tests at very high scale indicate that iterating over all connections
during a websocket broadcast starts to become a bottleneck.

To optimize this, we move the channelMember cache from
inside web_conn.go to the hubConnectionIndex.
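
Concretely, the membership map and the store handle now live on the
index itself. A trimmed view of the struct (the full version is in the
web_hub.go hunk below):

```go
// hubConnectionIndex, trimmed to the fields relevant to this change.
type hubConnectionIndex struct {
	// byUserId stores the list of connections for a given userID.
	byUserId map[string][]*WebConn
	// byChannelID stores the list of connections for a given channelID.
	byChannelID map[string][]*WebConn

	// The index now holds the store and a logger so it can load channel
	// memberships itself, instead of every WebConn querying on its own.
	store  store.Store
	logger mlog.LoggerIFace
}
```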

This involves adding a new map, keyed by channelID, containing all
webConns whose user is a member of that channel. A new method was also
needed to invalidate this cache, which previously happened inside
web_conn.
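
Registration is what fills the new map: hubConnectionIndex.Add loads
the user's channel memberships once and indexes the connection under
each channel (excerpted from the web_hub.go hunk below; the error is
propagated up through Hub.Register and HubRegister). The new
invalidation method, InvalidateCMCacheForUser, filters the user's
connections out of every channel slice and re-adds them from a fresh
membership query.

```go
// Add indexes a connection under every channel its user is a member of,
// in addition to the existing per-user and per-connection indexes.
func (i *hubConnectionIndex) Add(wc *WebConn) error {
	cm, err := i.store.Channel().GetAllChannelMembersForUser(
		request.EmptyContext(i.logger), wc.UserId, false, false)
	if err != nil {
		return fmt.Errorf("error getChannelMembersForUser: %v", err)
	}
	for chID := range cm {
		i.byChannelID[chID] = append(i.byChannelID[chID], wc)
	}
	i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc)
	i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1
	i.byConnectionId[wc.GetConnectionID()] = wc
	return nil
}
```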

As a last step, we remove the cache from web_conn, reducing SQL queries
to the DB.
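
With the index doing the scoping, the hub picks the target connections
for a channel broadcast directly, and the per-connection membership
lookup in ShouldSendEvent goes away (excerpted from the broadcast loop
in the web_hub.go hunk below):

```go
// Inside the hub's broadcast handling (per message):
var targetConns []*WebConn
if connID := msg.GetBroadcast().ConnectionId; connID != "" {
	if webConn := connIndex.ForConnection(connID); webConn != nil {
		targetConns = append(targetConns, webConn)
	}
} else if userID := msg.GetBroadcast().UserId; userID != "" {
	targetConns = connIndex.ForUser(userID)
} else if channelID := msg.GetBroadcast().ChannelId; channelID != "" {
	// New: channel-scoped events iterate only that channel's members.
	targetConns = connIndex.ForChannel(channelID)
}
// When a scope matched, only targetConns are broadcast to;
// otherwise the hub falls back to iterating all connections.
```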

https://mattermost.atlassian.net/browse/MM-61130

```release-note
NONE
```
Agniva De Sarker 2024-11-08 09:57:54 +05:30 committed by GitHub
parent 37d97e8024
commit bd8774bdce
11 changed files with 564 additions and 329 deletions


@ -16,6 +16,7 @@ import (
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"
@ -27,6 +28,7 @@ import (
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks"
"github.com/mattermost/mattermost/server/v8/channels/testlib"
"github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/channels/utils/testutils"
)
@ -2937,6 +2939,155 @@ func TestPermanentDeletePost(t *testing.T) {
})
}
func TestWebHubMembership(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
u1 := th.CreateUser()
th.LinkUserToTeam(u1, th.BasicTeam)
th.AddUserToChannel(u1, th.BasicChannel)
ch2 := th.CreatePrivateChannel()
u2 := th.CreateUser()
th.LinkUserToTeam(u2, th.BasicTeam)
th.AddUserToChannel(u2, ch2)
quitChan := make(chan struct{})
var wg sync.WaitGroup
wg.Add(3)
for _, obj := range []struct {
testName string
user *model.User
}{
{
testName: "basicUser",
user: th.BasicUser,
},
{
testName: "u1",
user: u1,
},
{
testName: "u2",
user: u2,
},
} {
cli := th.CreateClient()
_, _, err := cli.Login(context.Background(), obj.user.Username, obj.user.Password)
require.NoError(t, err)
wsClient, err := th.CreateWebSocketClientWithClient(cli)
require.NoError(t, err)
defer wsClient.Close()
wsClient.Listen()
go func(testName string) {
defer wg.Done()
var cnt int
for {
select {
case event := <-wsClient.EventChannel:
if event.EventType() == model.WebsocketEventPosted {
var post model.Post
err := json.Unmarshal([]byte(event.GetData()["post"].(string)), &post)
require.NoError(t, err)
cnt++
// Cases:
// Post to basicChannel should go to u1 and basicUser.
// Add u1 to ch2.
// Post to ch2 should go to u1, u2 and basicUser.
// Remove u1 from ch2.
// Post to ch2 should go to u2 and basicUser.
switch testName {
case "basicUser":
if cnt == 1 {
assert.Equal(t, th.BasicChannel.Id, post.ChannelId)
} else if cnt == 2 {
assert.Equal(t, ch2.Id, post.ChannelId)
} else if cnt == 3 {
// After removing, there will be a "removed from channel post"
assert.Equal(t, ch2.Id, post.ChannelId)
} else if cnt == 4 {
assert.Equal(t, ch2.Id, post.ChannelId)
} else {
assert.Fail(t, "more than 4 messages arrived for basicUser")
}
case "u1":
// First msg should be from basicChannel
if cnt == 1 {
assert.Equal(t, th.BasicChannel.Id, post.ChannelId)
} else if cnt == 2 {
// second should be from ch2
assert.Equal(t, ch2.Id, post.ChannelId)
} else {
assert.Fail(t, "more than 2 messages arrived for u1")
}
case "u2":
if cnt == 1 {
assert.Equal(t, ch2.Id, post.ChannelId)
} else if cnt == 2 {
// After removing, there will be a "removed from channel post"
assert.Equal(t, ch2.Id, post.ChannelId)
} else if cnt == 3 {
assert.Equal(t, ch2.Id, post.ChannelId)
} else {
assert.Fail(t, "more than 3 messages arrived for u2")
}
}
}
case <-quitChan:
return
}
}
}(obj.testName)
}
// Will send to basic channel
th.CreatePost()
// Add u1 to ch2
th.AddUserToChannel(u1, ch2)
// Send post to ch2
th.CreatePostWithClient(th.Client, ch2)
// Remove u1 from ch2
th.RemoveUserFromChannel(u1, ch2)
// Send post to ch2
th.CreatePostWithClient(th.Client, ch2)
// It is possible to create a signalling mechanism from the goroutines
// after all events are received, but we also want to verify that no additional
// events are being sent.
time.Sleep(2 * time.Second)
close(quitChan)
wg.Wait()
}
func TestWebHubCloseConnOnDBFail(t *testing.T) {
th := Setup(t).InitBasic()
defer func() {
th.TearDown()
// Asserting that the error message is present in the log
testlib.AssertLog(t, th.LogBuffer, mlog.LvlError.Name, "Error while registering to hub")
_, err := th.Server.Store().GetInternalMasterDB().Exec(`ALTER TABLE dummy RENAME to ChannelMembers`)
require.NoError(t, err)
}()
cli := th.CreateClient()
_, _, err := cli.Login(context.Background(), th.BasicUser.Username, th.BasicUser.Password)
require.NoError(t, err)
_, err = th.Server.Store().GetInternalMasterDB().Exec(`ALTER TABLE ChannelMembers RENAME to dummy`)
require.NoError(t, err)
wsClient, err := th.CreateWebSocketClientWithClient(cli)
require.NoError(t, err)
defer wsClient.Close()
require.NoError(t, th.TestLogger.Flush())
}
func TestDeletePostEvent(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()


@ -70,7 +70,7 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
} else {
cfg, err = c.App.Srv().Platform().PopulateWebConnConfig(c.AppContext.Session(), cfg, r.URL.Query().Get(sequenceNumberParam))
if err != nil {
c.Logger.Warn("Error while populating webconn config", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
c.Logger.Error("Error while populating webconn config", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
ws.Close()
return
}
@ -78,7 +78,12 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
wc := c.App.Srv().Platform().NewWebConn(cfg, c.App, c.App.Srv().Channels())
if c.AppContext.Session().UserId != "" {
c.App.Srv().Platform().HubRegister(wc)
err = c.App.Srv().Platform().HubRegister(wc)
if err != nil {
c.Logger.Error("Error while registering to hub", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
ws.Close()
return
}
}
wc.Pump()


@ -249,10 +249,6 @@ type AppIface interface {
GetUserStatusesByIds(userIDs []string) ([]*model.Status, *model.AppError)
// HasRemote returns whether a given channelID is present in the channel remotes or not.
HasRemote(channelID string, remoteID string) (bool, error)
// HubRegister registers a connection to a hub.
HubRegister(webConn *platform.WebConn)
// HubUnregister unregisters a connection from a hub.
HubUnregister(webConn *platform.WebConn)
// InstallPlugin unpacks and installs a plugin but does not enable or activate it unless the the
// plugin was already enabled.
InstallPlugin(pluginFile io.ReadSeeker, replace bool) (*model.Manifest, *model.AppError)


@ -523,7 +523,7 @@ func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectio
Locale: "en",
ConnectionID: connectionID,
}, th.App, th.App.Channels())
th.App.Srv().Platform().HubRegister(webConn)
require.NoError(t, th.App.Srv().Platform().HubRegister(webConn))
// Start reading from it
go webConn.Pump()


@ -11982,36 +11982,6 @@ func (a *OpenTracingAppLayer) HasSharedChannel(channelID string) (bool, error) {
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) HubRegister(webConn *platform.WebConn) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HubRegister")
a.ctx = newCtx
a.app.Srv().Store().SetContext(newCtx)
defer func() {
a.app.Srv().Store().SetContext(origCtx)
a.ctx = origCtx
}()
defer span.Finish()
a.app.HubRegister(webConn)
}
func (a *OpenTracingAppLayer) HubUnregister(webConn *platform.WebConn) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HubUnregister")
a.ctx = newCtx
a.app.Srv().Store().SetContext(newCtx)
defer func() {
a.app.Srv().Store().SetContext(origCtx)
a.ctx = origCtx
}()
defer span.Finish()
a.app.HubUnregister(webConn)
}
func (a *OpenTracingAppLayer) IPFiltering() einterfaces.IPFilteringInterface {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.IPFiltering")


@ -26,7 +26,6 @@ import (
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
)
const (
@ -95,10 +94,8 @@ type WebConn struct {
UserId string
PostedAck bool
allChannelMembers map[string]string
lastAllChannelMembersTime int64
lastUserActivityAt int64
send chan model.WebSocketMessage
lastUserActivityAt int64
send chan model.WebSocketMessage
// deadQueue behaves like a queue of a finite size
// which is used to store all messages that are sent via the websocket.
// It basically acts as the user-space socket buffer, and is used
@ -218,7 +215,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
if tcpConn != nil {
err := tcpConn.SetNoDelay(false)
if err != nil {
mlog.Warn("Error in setting NoDelay socket opts", mlog.Err(err))
ps.logger.Warn("Error in setting NoDelay socket opts", mlog.Err(err))
}
}
@ -321,6 +318,9 @@ func (wc *WebConn) SetConnectionID(id string) {
// GetConnectionID returns the connection id of the connection.
func (wc *WebConn) GetConnectionID() string {
if wc.connectionID.Load() == nil {
return ""
}
return wc.connectionID.Load().(string)
}
@ -566,7 +566,7 @@ func (wc *WebConn) writePump() {
err = enc.Encode(msg)
}
if err != nil {
mlog.Warn("Error in encoding websocket message", mlog.Err(err))
wc.Platform.logger.Warn("Error in encoding websocket message", mlog.Err(err))
continue
}
@ -581,7 +581,7 @@ func (wc *WebConn) writePump() {
logData = append(logData, mlog.String("channel_id", evt.GetBroadcast().ChannelId))
}
mlog.Warn("websocket.full", logData...)
wc.Platform.logger.Warn("websocket.full", logData...)
wc.lastLogTimeFull = time.Now()
}
@ -608,7 +608,7 @@ func (wc *WebConn) writePump() {
case <-authTicker.C:
if wc.GetSessionToken() == "" {
mlog.Debug("websocket.authTicker: did not authenticate", mlog.Stringer("ip_address", wc.WebSocket.RemoteAddr()))
wc.Platform.logger.Debug("websocket.authTicker: did not authenticate", mlog.Stringer("ip_address", wc.WebSocket.RemoteAddr()))
return
}
authTicker.Stop()
@ -629,7 +629,7 @@ func (wc *WebConn) writeMessage(msg *model.WebSocketEvent) error {
var buf bytes.Buffer
err := msg.Encode(json.NewEncoder(&buf), &buf)
if err != nil {
mlog.Warn("Error in encoding websocket message", mlog.Err(err))
wc.Platform.logger.Warn("Error in encoding websocket message", mlog.Err(err))
return nil
}
wc.Sequence++
@ -734,8 +734,6 @@ func (wc *WebConn) drainDeadQueue(index int) error {
// InvalidateCache resets all internal data of the WebConn.
func (wc *WebConn) InvalidateCache() {
wc.allChannelMembers = nil
wc.lastAllChannelMembersTime = 0
wc.SetSession(nil)
wc.SetSessionExpiresAt(0)
}
@ -751,9 +749,9 @@ func (wc *WebConn) IsAuthenticated() bool {
session, err := wc.Suite.GetSession(wc.GetSessionToken())
if err != nil {
if err.StatusCode >= http.StatusBadRequest && err.StatusCode < http.StatusInternalServerError {
mlog.Debug("Invalid session.", mlog.Err(err))
wc.Platform.logger.Debug("Invalid session.", mlog.Err(err))
} else {
mlog.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
wc.Platform.logger.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
}
wc.SetSessionToken("")
@ -789,7 +787,7 @@ func (wc *WebConn) ShouldSendEventToGuest(msg *model.WebSocketEvent) bool {
case model.WebsocketEventUserUpdated:
user, ok := msg.GetData()["user"].(*model.User)
if !ok {
mlog.Debug("webhub.shouldSendEvent: user not found in message", mlog.Any("user", msg.GetData()["user"]))
wc.Platform.logger.Debug("webhub.shouldSendEvent: user not found in message", mlog.Any("user", msg.GetData()["user"]))
return false
}
userID = user.Id
@ -828,7 +826,7 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
model.WebsocketEventStatusChange,
model.WebsocketEventMultipleChannelsViewed:
if wc.Active.Load() && time.Since(wc.lastLogTimeSlow) > websocketSuppressWarnThreshold {
mlog.Warn(
wc.Platform.logger.Warn(
"websocket.slow: dropping message",
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()),
@ -893,8 +891,8 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
// Only report events to users who are in the channel for the event
if chID := msg.GetBroadcast().ChannelId; chID != "" {
// For typing events, we don't send them to users who don't have
// that channel or thread opened.
// For typing/reaction_added/reaction_removed events, we don't send them to users
// who don't have that channel or thread opened.
if wc.Platform.Config().FeatureFlags.WebSocketEventScope &&
slices.Contains([]model.WebsocketEventType{
model.WebsocketEventTyping,
@ -904,30 +902,9 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
return false
}
if model.GetMillis()-wc.lastAllChannelMembersTime > webConnMemberCacheTime {
wc.allChannelMembers = nil
wc.lastAllChannelMembersTime = 0
}
if wc.allChannelMembers == nil {
result, err := wc.Platform.Store.Channel().GetAllChannelMembersForUser(
sqlstore.RequestContextWithMaster(request.EmptyContext(wc.Platform.logger)),
wc.UserId,
false,
false,
)
if err != nil {
mlog.Error("webhub.shouldSendEvent.", mlog.Err(err))
return false
}
wc.allChannelMembers = result
wc.lastAllChannelMembersTime = model.GetMillis()
}
if _, ok := wc.allChannelMembers[chID]; ok {
return true
}
return false
// We don't need to do any further checks because this is already scoped
// to channel members from web_hub.
return true
}
// Only report events to users who are in the team for the event
@ -960,9 +937,9 @@ func (wc *WebConn) isMemberOfTeam(teamID string) bool {
session, err := wc.Suite.GetSession(wc.GetSessionToken())
if err != nil {
if err.StatusCode >= http.StatusBadRequest && err.StatusCode < http.StatusInternalServerError {
mlog.Debug("Invalid session.", mlog.Err(err))
wc.Platform.logger.Debug("Invalid session.", mlog.Err(err))
} else {
mlog.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
wc.Platform.logger.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
}
return false
}
@ -976,12 +953,12 @@ func (wc *WebConn) isMemberOfTeam(teamID string) bool {
func (wc *WebConn) logSocketErr(source string, err error) {
// browsers will appear as CloseNoStatusReceived
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
mlog.Debug(source+": client side closed socket",
wc.Platform.logger.Debug(source+": client side closed socket",
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()),
mlog.String("origin_client", wc.originClient))
} else {
mlog.Debug(source+": closing websocket",
wc.Platform.logger.Debug(source+": closing websocket",
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()),
mlog.String("origin_client", wc.originClient),


@ -4,6 +4,7 @@
package platform
import (
"fmt"
"hash/maphash"
"runtime"
"runtime/debug"
@ -14,6 +15,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store"
)
const (
@ -45,6 +47,11 @@ type webConnSessionMessage struct {
isRegistered chan bool
}
type webConnRegisterMessage struct {
conn *WebConn
err chan error
}
type webConnCheckMessage struct {
userID string
connectionID string
@ -65,7 +72,7 @@ type Hub struct {
connectionCount int64
platform *PlatformService
connectionIndex int
register chan *WebConn
register chan *webConnRegisterMessage
unregister chan *WebConn
broadcast chan *model.WebSocketEvent
stop chan struct{}
@ -84,7 +91,7 @@ type Hub struct {
func newWebHub(ps *PlatformService) *Hub {
return &Hub{
platform: ps,
register: make(chan *WebConn),
register: make(chan *webConnRegisterMessage),
unregister: make(chan *WebConn),
broadcast: make(chan *model.WebSocketEvent, broadcastQueueSize),
stop: make(chan struct{}),
@ -150,14 +157,15 @@ func (ps *PlatformService) GetHubForUserId(userID string) *Hub {
}
// HubRegister registers a connection to a hub.
func (ps *PlatformService) HubRegister(webConn *WebConn) {
func (ps *PlatformService) HubRegister(webConn *WebConn) error {
hub := ps.GetHubForUserId(webConn.UserId)
if hub != nil {
if metrics := ps.metricsIFace; metrics != nil {
metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
}
hub.Register(webConn)
return hub.Register(webConn)
}
return nil
}
// HubUnregister unregisters a connection from a hub.
@ -262,11 +270,17 @@ func (ps *PlatformService) WebConnCountForUser(userID string) int {
}
// Register registers a connection to the hub.
func (h *Hub) Register(webConn *WebConn) {
func (h *Hub) Register(webConn *WebConn) error {
wr := &webConnRegisterMessage{
conn: webConn,
err: make(chan error),
}
select {
case h.register <- webConn:
case h.register <- wr:
return <-wr.err
case <-h.stop:
}
return nil
}
// Unregister unregisters a connection from the hub.
@ -389,7 +403,10 @@ func (h *Hub) Start() {
ticker := time.NewTicker(inactiveConnReaperInterval)
defer ticker.Stop()
connIndex := newHubConnectionIndex(inactiveConnReaperInterval)
connIndex := newHubConnectionIndex(inactiveConnReaperInterval,
h.platform.Store,
h.platform.logger,
)
for {
select {
@ -423,22 +440,27 @@ func (h *Hub) Start() {
req.result <- connIndex.ForUserActiveCount(req.userID)
case <-ticker.C:
connIndex.RemoveInactiveConnections()
case webConn := <-h.register:
case webConnReg := <-h.register:
// Mark the current one as active.
// There is no need to check if it was inactive or not,
// we will anyways need to make it active.
webConn.Active.Store(true)
webConnReg.conn.Active.Store(true)
connIndex.Add(webConn)
err := connIndex.Add(webConnReg.conn)
if err != nil {
webConnReg.err <- err
continue
}
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
if webConn.IsAuthenticated() && webConn.reuseCount == 0 {
if webConnReg.conn.IsAuthenticated() && webConnReg.conn.reuseCount == 0 {
// The hello message should only be sent when the reuseCount is 0.
// i.e in server restart, or long timeout, or fresh connection case.
// In case of seq number not found in dead queue, it is handled by
// the webconn write pump.
webConn.send <- webConn.createHelloMessage()
webConnReg.conn.send <- webConnReg.conn.createHelloMessage()
}
webConnReg.err <- nil
case webConn := <-h.unregister:
// If already removed (via queue full), then removing again becomes a noop.
// But if not removed, mark inactive.
@ -497,6 +519,13 @@ func (h *Hub) Start() {
for _, webConn := range connIndex.ForUser(userID) {
webConn.InvalidateCache()
}
err := connIndex.InvalidateCMCacheForUser(userID)
if err != nil {
h.platform.Log().Error("Error while invalidating channel member cache", mlog.String("user_id", userID), mlog.Err(err))
for _, webConn := range connIndex.ForUser(userID) {
closeAndRemoveConn(connIndex, webConn)
}
}
case activity := <-h.activity:
for _, webConn := range connIndex.ForUser(activity.userID) {
if !webConn.Active.Load() {
@ -519,8 +548,7 @@ func (h *Hub) Start() {
mlog.String("user_id", directMsg.conn.UserId),
mlog.String("conn_id", directMsg.conn.GetConnectionID()))
}
close(directMsg.conn.send)
connIndex.Remove(directMsg.conn)
closeAndRemoveConn(connIndex, directMsg.conn)
}
case msg := <-h.broadcast:
if metrics := h.platform.metricsIFace; metrics != nil {
@ -546,27 +574,29 @@ func (h *Hub) Start() {
mlog.String("user_id", webConn.UserId),
mlog.String("conn_id", webConn.GetConnectionID()))
}
close(webConn.send)
connIndex.Remove(webConn)
closeAndRemoveConn(connIndex, webConn)
}
}
}
var targetConns []*WebConn
if connID := msg.GetBroadcast().ConnectionId; connID != "" {
if webConn := connIndex.ForConnection(connID); webConn != nil {
broadcast(webConn)
continue
targetConns = append(targetConns, webConn)
}
} else if msg.GetBroadcast().UserId != "" {
candidates := connIndex.ForUser(msg.GetBroadcast().UserId)
for _, webConn := range candidates {
} else if userID := msg.GetBroadcast().UserId; userID != "" {
targetConns = connIndex.ForUser(userID)
} else if channelID := msg.GetBroadcast().ChannelId; channelID != "" {
targetConns = connIndex.ForChannel(channelID)
}
if targetConns != nil {
for _, webConn := range targetConns {
broadcast(webConn)
}
continue
}
candidates := connIndex.All()
for webConn := range candidates {
for webConn := range connIndex.All() {
broadcast(webConn)
}
case <-h.stop:
@ -616,14 +646,24 @@ func areAllInactive(conns []*WebConn) bool {
return true
}
// closeAndRemoveConn closes the send channel which will close the
// websocket connection, and then it removes the webConn from the conn index.
func closeAndRemoveConn(connIndex *hubConnectionIndex, conn *WebConn) {
close(conn.send)
connIndex.Remove(conn)
}
// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
// It requires 3 functionalities which need to be very fast:
// It requires 4 functionalities which need to be very fast:
// - check if a connection exists or not.
// - get all connections for a given userID.
// - get all connections for a given channelID.
// - get all connections.
type hubConnectionIndex struct {
// byUserId stores the list of connections for a given userID
byUserId map[string][]*WebConn
// byChannelID stores the list of connections for a given channelID.
byChannelID map[string][]*WebConn
// byConnection serves the dual purpose of storing the index of the webconn
// in the value of byUserId map, and also to get all connections.
byConnection map[*WebConn]int
@ -631,21 +671,39 @@ type hubConnectionIndex struct {
// staleThreshold is the limit beyond which inactive connections
// will be deleted.
staleThreshold time.Duration
store store.Store
logger mlog.LoggerIFace
}
func newHubConnectionIndex(interval time.Duration) *hubConnectionIndex {
func newHubConnectionIndex(interval time.Duration,
store store.Store,
logger mlog.LoggerIFace,
) *hubConnectionIndex {
return &hubConnectionIndex{
byUserId: make(map[string][]*WebConn),
byChannelID: make(map[string][]*WebConn),
byConnection: make(map[*WebConn]int),
byConnectionId: make(map[string]*WebConn),
staleThreshold: interval,
store: store,
logger: logger,
}
}
func (i *hubConnectionIndex) Add(wc *WebConn) {
func (i *hubConnectionIndex) Add(wc *WebConn) error {
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), wc.UserId, false, false)
if err != nil {
return fmt.Errorf("error getChannelMembersForUser: %v", err)
}
for chID := range cm {
i.byChannelID[chID] = append(i.byChannelID[chID], wc)
}
i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc)
i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1
i.byConnectionId[wc.GetConnectionID()] = wc
return nil
}
func (i *hubConnectionIndex) Remove(wc *WebConn) {
@ -654,21 +712,67 @@ func (i *hubConnectionIndex) Remove(wc *WebConn) {
return
}
// Remove the wc from i.byUserId
// get the conn slice.
userConnections := i.byUserId[wc.UserId]
// get the last connection.
last := userConnections[len(userConnections)-1]
// set the slot that we are trying to remove to be the last connection.
// https://go.dev/wiki/SliceTricks#delete-without-preserving-order
userConnections[userConnIndex] = last
// remove the last connection pointer from slice.
userConnections[len(userConnections)-1] = nil
// remove the last connection from the slice.
i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1]
// set the index of the connection that was moved to the new index.
i.byConnection[last] = userConnIndex
connectionID := wc.GetConnectionID()
// Remove webconns from i.byChannelID
// This has O(n) complexity. We are trading off speed while removing
// a connection, to improve broadcasting a message.
for chID, webConns := range i.byChannelID {
// https://go.dev/wiki/SliceTricks#filtering-without-allocating
filtered := webConns[:0]
for _, conn := range webConns {
if conn.GetConnectionID() != connectionID {
filtered = append(filtered, conn)
}
}
for i := len(filtered); i < len(webConns); i++ {
webConns[i] = nil
}
i.byChannelID[chID] = filtered
}
delete(i.byConnection, wc)
delete(i.byConnectionId, wc.GetConnectionID())
delete(i.byConnectionId, connectionID)
}
func (i *hubConnectionIndex) InvalidateCMCacheForUser(userID string) error {
// We make this query first to fail fast in case of an error.
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), userID, false, false)
if err != nil {
return err
}
// Clear out all user entries which belong to channels.
for chID, webConns := range i.byChannelID {
// https://go.dev/wiki/SliceTricks#filtering-without-allocating
filtered := webConns[:0]
for _, conn := range webConns {
if conn.UserId != userID {
filtered = append(filtered, conn)
}
}
for i := len(filtered); i < len(webConns); i++ {
webConns[i] = nil
}
i.byChannelID[chID] = filtered
}
// re-populate the cache
for chID := range cm {
i.byChannelID[chID] = append(i.byChannelID[chID], i.ForUser(userID)...)
}
return nil
}
func (i *hubConnectionIndex) Has(wc *WebConn) bool {
@ -691,6 +795,16 @@ func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
return conns
}
// ForChannel returns all connections for a channelID.
func (i *hubConnectionIndex) ForChannel(channelID string) []*WebConn {
// Note: this is expensive because usually there will be
// more than 1 member for a channel, and broadcasting
// is a hot path, but worth it.
conns := make([]*WebConn, len(i.byChannelID[channelID]))
copy(conns, i.byChannelID[channelID])
return conns
}
// ForUserActiveCount returns the number of active connections for a userID
func (i *hubConnectionIndex) ForUserActiveCount(id string) int {
cnt := 0


@ -15,13 +15,11 @@ import (
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/i18n"
platform_mocks "github.com/mattermost/mattermost/server/v8/channels/app/platform/mocks"
"github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks"
"github.com/mattermost/mattermost/server/v8/channels/testlib"
)
@ -53,7 +51,7 @@ func registerDummyWebConn(t *testing.T, th *TestHelper, addr net.Addr, session *
Locale: "en",
}
wc := th.Service.NewWebConn(cfg, th.Suite, &hookRunner{})
th.Service.HubRegister(wc)
require.NoError(t, th.Service.HubRegister(wc))
go wc.Pump()
return wc
}
@ -105,8 +103,8 @@ func TestHubStopRaceCondition(t *testing.T) {
go func() {
wc4 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc5 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
hub.Register(wc4)
hub.Register(wc5)
require.NoError(t, hub.Register(wc4))
require.NoError(t, hub.Register(wc5))
hub.UpdateActivity("userId", "sessionToken", 0)
@ -128,50 +126,9 @@ func TestHubStopRaceCondition(t *testing.T) {
}
func TestHubSessionRevokeRace(t *testing.T) {
th := SetupWithStoreMock(t)
th := Setup(t)
defer th.TearDown()
sess1 := &model.Session{
Id: "id1",
UserId: "user1",
DeviceId: "",
Token: "sesstoken",
ExpiresAt: model.GetMillis() + 300000,
LastActivityAt: 10000,
}
mockStore := th.Service.Store.(*mocks.Store)
mockUserStore := mocks.UserStore{}
mockUserStore.On("Count", mock.Anything).Return(int64(10), nil)
mockUserStore.On("GetUnreadCount", mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(int64(1), nil)
mockPostStore := mocks.PostStore{}
mockPostStore.On("GetMaxPostSize").Return(65535, nil)
mockSystemStore := mocks.SystemStore{}
mockSystemStore.On("GetByName", "UpgradedFromTE").Return(&model.System{Name: "UpgradedFromTE", Value: "false"}, nil)
mockSystemStore.On("GetByName", "InstallationDate").Return(&model.System{Name: "InstallationDate", Value: "10"}, nil)
mockSystemStore.On("GetByName", "FirstServerRunTimestamp").Return(&model.System{Name: "FirstServerRunTimestamp", Value: "10"}, nil)
mockSessionStore := mocks.SessionStore{}
mockSessionStore.On("UpdateLastActivityAt", "id1", mock.Anything).Return(nil)
mockSessionStore.On("Save", mock.AnythingOfType("*request.Context"), mock.AnythingOfType("*model.Session")).Return(sess1, nil)
mockSessionStore.On("Get", mock.AnythingOfType("*request.Context"), mock.Anything, "id1").Return(sess1, nil)
mockSessionStore.On("Remove", "id1").Return(nil)
mockStatusStore := mocks.StatusStore{}
mockStatusStore.On("Get", "user1").Return(&model.Status{UserId: "user1", Status: model.StatusOnline}, nil)
mockStatusStore.On("UpdateLastActivityAt", "user1", mock.Anything).Return(nil)
mockStatusStore.On("SaveOrUpdate", mock.AnythingOfType("*model.Status")).Return(nil)
mockOAuthStore := mocks.OAuthStore{}
mockStore.On("Session").Return(&mockSessionStore)
mockStore.On("OAuth").Return(&mockOAuthStore)
mockStore.On("Status").Return(&mockStatusStore)
mockStore.On("User").Return(&mockUserStore)
mockStore.On("Post").Return(&mockPostStore)
mockStore.On("System").Return(&mockSystemStore)
mockStore.On("GetDBSchemaVersion").Return(1, nil)
// This needs to be false for the condition to trigger
th.Service.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.ExtendSessionLengthWithActivity = false
@ -181,7 +138,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
defer s.Close()
session, err := th.Service.CreateSession(th.Context, &model.Session{
UserId: "testid",
UserId: model.NewId(),
})
require.NoError(t, err)
@ -193,7 +150,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
time.Sleep(2 * time.Second)
// We override the LastActivityAt which happens in NewWebConn.
// This is needed to call RevokeSessionById which triggers the race.
th.Service.AddSessionToCache(sess1)
th.Service.AddSessionToCache(session)
go func() {
for i := 0; i <= broadcastQueueSize; i++ {
@ -203,7 +160,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
}()
// This call should happen _after_ !wc.IsAuthenticated() and _before_wc.isMemberOfTeam().
// There's no guarantee this will happen. But that's out best bet to trigger this race.
// There's no guarantee this will happen. But that's our best bet to trigger this race.
wc1.InvalidateCache()
for i := 0; i < 10; i++ {
@ -220,92 +177,236 @@ func TestHubSessionRevokeRace(t *testing.T) {
}
func TestHubConnIndex(t *testing.T) {
th := Setup(t)
th := Setup(t).InitBasic()
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
// User1
wc1 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc1.SetConnectionID(model.NewId())
wc1.SetSession(&model.Session{})
// User2
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc2.SetConnectionID(model.NewId())
wc2.SetSession(&model.Session{})
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID(model.NewId())
wc3.SetSession(&model.Session{})
wc4 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc4.SetConnectionID(model.NewId())
wc4.SetSession(&model.Session{})
connIndex.Add(wc1)
connIndex.Add(wc2)
connIndex.Add(wc3)
connIndex.Add(wc4)
_, err := th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
ChannelId: th.BasicChannel.Id,
UserId: th.BasicUser.Id,
NotifyProps: model.GetDefaultChannelNotifyProps(),
SchemeGuest: th.BasicUser.IsGuest(),
SchemeUser: !th.BasicUser.IsGuest(),
})
require.NoError(t, err)
_, err = th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
ChannelId: th.BasicChannel.Id,
UserId: th.BasicUser2.Id,
NotifyProps: model.GetDefaultChannelNotifyProps(),
SchemeGuest: th.BasicUser2.IsGuest(),
SchemeUser: !th.BasicUser2.IsGuest(),
})
require.NoError(t, err)
t.Run("Basic", func(t *testing.T) {
assert.True(t, connIndex.Has(wc1))
assert.True(t, connIndex.Has(wc2))
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc3, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
assert.True(t, connIndex.Has(wc2))
assert.True(t, connIndex.Has(wc1))
assert.Len(t, connIndex.All(), 4)
// User1
wc1 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc1.SetConnectionID(model.NewId())
wc1.SetSession(&model.Session{})
// User2
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc2.SetConnectionID(model.NewId())
wc2.SetSession(&model.Session{})
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID(model.NewId())
wc3.SetSession(&model.Session{})
wc4 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc4.SetConnectionID(model.NewId())
wc4.SetSession(&model.Session{})
connIndex.Add(wc1)
connIndex.Add(wc2)
connIndex.Add(wc3)
connIndex.Add(wc4)
t.Run("Basic", func(t *testing.T) {
assert.True(t, connIndex.Has(wc1))
assert.True(t, connIndex.Has(wc2))
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc3, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
assert.True(t, connIndex.Has(wc2))
assert.True(t, connIndex.Has(wc1))
assert.Len(t, connIndex.All(), 4)
})
t.Run("RemoveMiddleUser2", func(t *testing.T) {
connIndex.Remove(wc3) // Remove from middle from user2
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
assert.True(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.True(t, connIndex.Has(wc4))
assert.Len(t, connIndex.All(), 3)
})
t.Run("RemoveUser1", func(t *testing.T) {
connIndex.Remove(wc1) // Remove sole connection from user1
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
assert.Len(t, connIndex.ForUser(wc1.UserId), 0)
assert.Len(t, connIndex.All(), 2)
assert.False(t, connIndex.Has(wc1))
assert.True(t, connIndex.Has(wc2))
})
t.Run("RemoveEndUser2", func(t *testing.T) {
connIndex.Remove(wc4) // Remove from end from user2
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
assert.True(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.False(t, connIndex.Has(wc4))
assert.Len(t, connIndex.All(), 1)
})
})
t.Run("RemoveMiddleUser2", func(t *testing.T) {
connIndex.Remove(wc3) // Remove from middle from user2
t.Run("ByConnectionId", func(t *testing.T) {
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
assert.True(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.True(t, connIndex.Has(wc4))
assert.Len(t, connIndex.All(), 3)
// User1
wc1ID := model.NewId()
wc1 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: th.BasicUser.Id,
}
wc1.SetConnectionID(wc1ID)
wc1.SetSession(&model.Session{})
// User2
wc2ID := model.NewId()
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: th.BasicUser2.Id,
}
wc2.SetConnectionID(wc2ID)
wc2.SetSession(&model.Session{})
wc3ID := model.NewId()
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID(wc3ID)
wc3.SetSession(&model.Session{})
t.Run("no connections", func(t *testing.T) {
assert.False(t, connIndex.Has(wc1))
assert.False(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.Empty(t, connIndex.byConnectionId)
})
t.Run("adding", func(t *testing.T) {
connIndex.Add(wc1)
connIndex.Add(wc3)
assert.Len(t, connIndex.byConnectionId, 2)
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
t.Run("removing", func(t *testing.T) {
connIndex.Remove(wc3)
assert.Len(t, connIndex.byConnectionId, 1)
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
})
t.Run("RemoveUser1", func(t *testing.T) {
connIndex.Remove(wc1) // Remove sole connection from user1
t.Run("ByChannelId", func(t *testing.T) {
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
assert.Len(t, connIndex.ForUser(wc1.UserId), 0)
assert.Len(t, connIndex.All(), 2)
assert.False(t, connIndex.Has(wc1))
assert.True(t, connIndex.Has(wc2))
})
// User1
wc1ID := model.NewId()
wc1 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: th.BasicUser.Id,
}
wc1.SetConnectionID(wc1ID)
wc1.SetSession(&model.Session{})
t.Run("RemoveEndUser2", func(t *testing.T) {
connIndex.Remove(wc4) // Remove from end from user2
// User2
wc2ID := model.NewId()
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: th.BasicUser2.Id,
}
wc2.SetConnectionID(wc2ID)
wc2.SetSession(&model.Session{})
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2})
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
assert.True(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.False(t, connIndex.Has(wc4))
assert.Len(t, connIndex.All(), 1)
wc3ID := model.NewId()
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID(wc3ID)
wc3.SetSession(&model.Session{})
connIndex.Add(wc1)
connIndex.Add(wc2)
connIndex.Add(wc3)
t.Run("ForChannel", func(t *testing.T) {
require.Len(t, connIndex.byChannelID, 1)
require.Equal(t, []*WebConn{wc1, wc2, wc3}, connIndex.ForChannel(th.BasicChannel.Id))
require.Len(t, connIndex.ForChannel("notexist"), 0)
})
ch := th.CreateChannel(th.BasicTeam)
_, err = th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
ChannelId: ch.Id,
UserId: th.BasicUser2.Id,
NotifyProps: model.GetDefaultChannelNotifyProps(),
SchemeGuest: th.BasicUser2.IsGuest(),
SchemeUser: !th.BasicUser2.IsGuest(),
})
require.NoError(t, err)
t.Run("InvalidateCMCacheForUser", func(t *testing.T) {
require.NoError(t, connIndex.InvalidateCMCacheForUser(th.BasicUser2.Id))
require.Len(t, connIndex.byChannelID, 2)
require.Len(t, connIndex.ForChannel(th.BasicChannel.Id), 3)
require.Len(t, connIndex.ForChannel(ch.Id), 2)
})
t.Run("Remove", func(t *testing.T) {
connIndex.Remove(wc3)
require.Len(t, connIndex.byChannelID, 2)
require.Len(t, connIndex.ForChannel(th.BasicChannel.Id), 2)
})
})
}
@ -313,7 +414,7 @@ func TestHubConnIndexIncorrectRemoval(t *testing.T) {
th := Setup(t)
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
// User2
wc2 := &WebConn{
@ -356,73 +457,11 @@ func TestHubConnIndexIncorrectRemoval(t *testing.T) {
}
}
func TestHubConnIndexByConnectionId(t *testing.T) {
th := Setup(t)
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
// User1
wc1ID := model.NewId()
wc1 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc1.SetConnectionID(wc1ID)
wc1.SetSession(&model.Session{})
// User2
wc2ID := model.NewId()
wc2 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: model.NewId(),
}
wc2.SetConnectionID(wc2ID)
wc2.SetSession(&model.Session{})
wc3ID := model.NewId()
wc3 := &WebConn{
Platform: th.Service,
Suite: th.Suite,
UserId: wc2.UserId,
}
wc3.SetConnectionID(wc3ID)
wc3.SetSession(&model.Session{})
t.Run("no connections", func(t *testing.T) {
assert.False(t, connIndex.Has(wc1))
assert.False(t, connIndex.Has(wc2))
assert.False(t, connIndex.Has(wc3))
assert.Empty(t, connIndex.byConnectionId)
})
t.Run("adding", func(t *testing.T) {
connIndex.Add(wc1)
connIndex.Add(wc3)
assert.Len(t, connIndex.byConnectionId, 2)
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
t.Run("removing", func(t *testing.T) {
connIndex.Remove(wc3)
assert.Len(t, connIndex.byConnectionId, 1)
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
})
}
func TestHubConnIndexInactive(t *testing.T) {
th := Setup(t)
defer th.TearDown()
connIndex := newHubConnectionIndex(2 * time.Second)
connIndex := newHubConnectionIndex(2*time.Second, th.Service.Store, th.Service.logger)
// User1
wc1 := &WebConn{
@ -582,7 +621,7 @@ func TestHubWebConnCount(t *testing.T) {
func BenchmarkHubConnIndex(b *testing.B) {
th := Setup(b).InitBasic()
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
// User1
wc1 := &WebConn{
@ -627,7 +666,7 @@ func TestHubConnIndexRemoveMemLeak(t *testing.T) {
th := Setup(t)
defer th.TearDown()
connIndex := newHubConnectionIndex(1 * time.Second)
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
wc := &WebConn{
Platform: th.Service,


@ -49,6 +49,7 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque
session, err := conn.Suite.GetSession(token)
if err != nil {
conn.Platform.Log().Warn("Error while getting session token", mlog.Err(err))
conn.WebSocket.Close()
return
}
@ -56,7 +57,12 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque
conn.SetSessionToken(session.Token)
conn.UserId = session.UserId
conn.Platform.HubRegister(conn)
nErr := conn.Platform.HubRegister(conn)
if nErr != nil {
conn.Platform.Log().Error("Error while registering to hub", mlog.String("user_id", conn.UserId), mlog.Err(nErr))
conn.WebSocket.Close()
return
}
conn.Platform.Go(func() {
conn.Platform.SetStatusOnline(session.UserId, false)


@ -15,6 +15,9 @@ import (
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
)
// TestWebConnShouldSendEvent is not exhaustive because some of the checks
// happen inside web_hub.go before the event is actually broadcasted, and checked
// via ShouldSendEvent.
func TestWebConnShouldSendEvent(t *testing.T) {
os.Setenv("MM_FEATUREFLAGS_WEBSOCKETEVENTSCOPE", "true")
defer os.Unsetenv("MM_FEATUREFLAGS_WEBSOCKETEVENTSCOPE")
@ -157,14 +160,6 @@ func TestWebConnShouldSendEvent(t *testing.T) {
})
}
t.Run("should send to basic user in basic channel", func(t *testing.T) {
event = event.SetBroadcast(&model.WebsocketBroadcast{ChannelId: th.BasicChannel.Id})
assert.True(t, basicUserWc.ShouldSendEvent(event), "expected user 1")
assert.False(t, basicUser2Wc.ShouldSendEvent(event), "did not expect user 2")
assert.False(t, adminUserWc.ShouldSendEvent(event), "did not expect admin")
})
t.Run("should not send typing event unless in scope", func(t *testing.T) {
event2 := model.NewWebSocketEvent(model.WebsocketEventTyping, "", th.BasicChannel.Id, "", nil, "")
// Basic, unset case
@ -222,14 +217,6 @@ func TestWebConnShouldSendEvent(t *testing.T) {
assert.False(t, basicUserWc.ShouldSendEvent(event2))
})
t.Run("should send to basic user and admin in channel2", func(t *testing.T) {
event = event.SetBroadcast(&model.WebsocketBroadcast{ChannelId: channel2.Id})
assert.True(t, basicUserWc.ShouldSendEvent(event), "expected user 1")
assert.False(t, basicUser2Wc.ShouldSendEvent(event), "did not expect user 2")
assert.True(t, adminUserWc.ShouldSendEvent(event), "expected admin")
})
t.Run("channel member cache invalidated after user added to channel", func(t *testing.T) {
th.AddUserToChannel(th.BasicUser2, channel2)
basicUser2Wc.InvalidateCache()


@ -16,16 +16,6 @@ func (a *App) GetHubForUserId(userID string) *platform.Hub {
return a.Srv().Platform().GetHubForUserId(userID)
}
// HubRegister registers a connection to a hub.
func (a *App) HubRegister(webConn *platform.WebConn) {
a.Srv().Platform().HubRegister(webConn)
}
// HubUnregister unregisters a connection from a hub.
func (a *App) HubUnregister(webConn *platform.WebConn) {
a.Srv().Platform().HubUnregister(webConn)
}
func (a *App) Publish(message *model.WebSocketEvent) {
a.Srv().Platform().Publish(message)
}