mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-61130: Use a channelMember map at web_hub level (#28810)
Tests at very high scale indicates that the iteration of all connections during websocket broadcast starts to become a bottleneck. To optimize this, we move the channelMember cache from inside web_conn.go to the hubConnectionIndex. This involves adding a new map keyed by the channelID and containing all webConns where the user is a member of that channel. Subsequently, a new method needed to be added to invalidate the cache which previously used to happen in web_conn. And as a last step, we remove the cache from web_conn to reduce SQL queries to the DB. https://mattermost.atlassian.net/browse/MM-61130 ```release-note NONE ```
This commit is contained in:
parent
37d97e8024
commit
bd8774bdce
@ -16,6 +16,7 @@ import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -27,6 +28,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/testlib"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/utils/testutils"
|
||||
)
|
||||
@ -2937,6 +2939,155 @@ func TestPermanentDeletePost(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWebHubMembership(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
u1 := th.CreateUser()
|
||||
th.LinkUserToTeam(u1, th.BasicTeam)
|
||||
th.AddUserToChannel(u1, th.BasicChannel)
|
||||
|
||||
ch2 := th.CreatePrivateChannel()
|
||||
u2 := th.CreateUser()
|
||||
th.LinkUserToTeam(u2, th.BasicTeam)
|
||||
th.AddUserToChannel(u2, ch2)
|
||||
|
||||
quitChan := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(3)
|
||||
for _, obj := range []struct {
|
||||
testName string
|
||||
user *model.User
|
||||
}{
|
||||
{
|
||||
testName: "basicUser",
|
||||
user: th.BasicUser,
|
||||
},
|
||||
{
|
||||
testName: "u1",
|
||||
user: u1,
|
||||
},
|
||||
{
|
||||
testName: "u2",
|
||||
user: u2,
|
||||
},
|
||||
} {
|
||||
cli := th.CreateClient()
|
||||
_, _, err := cli.Login(context.Background(), obj.user.Username, obj.user.Password)
|
||||
require.NoError(t, err)
|
||||
|
||||
wsClient, err := th.CreateWebSocketClientWithClient(cli)
|
||||
require.NoError(t, err)
|
||||
defer wsClient.Close()
|
||||
|
||||
wsClient.Listen()
|
||||
|
||||
go func(testName string) {
|
||||
defer wg.Done()
|
||||
var cnt int
|
||||
for {
|
||||
select {
|
||||
case event := <-wsClient.EventChannel:
|
||||
if event.EventType() == model.WebsocketEventPosted {
|
||||
var post model.Post
|
||||
err := json.Unmarshal([]byte(event.GetData()["post"].(string)), &post)
|
||||
require.NoError(t, err)
|
||||
|
||||
cnt++
|
||||
// Cases:
|
||||
// Post to basicChannel should go to u1 and basicUser.
|
||||
// Add u1 to ch2.
|
||||
// Post to ch2 should go to u1, u2 and basicUser.
|
||||
// Remove u1 from ch2.
|
||||
// Post to ch2 should go to u2 and basicUser.
|
||||
switch testName {
|
||||
case "basicUser":
|
||||
if cnt == 1 {
|
||||
assert.Equal(t, th.BasicChannel.Id, post.ChannelId)
|
||||
} else if cnt == 2 {
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else if cnt == 3 {
|
||||
// After removing, there will be a "removed from channel post"
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else if cnt == 4 {
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else {
|
||||
assert.Fail(t, "more than 4 messages arrived for basicUser")
|
||||
}
|
||||
case "u1":
|
||||
// First msg should be from basicChannel
|
||||
if cnt == 1 {
|
||||
assert.Equal(t, th.BasicChannel.Id, post.ChannelId)
|
||||
} else if cnt == 2 {
|
||||
// second should be from ch2
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else {
|
||||
assert.Fail(t, "more than 2 messages arrived for u1")
|
||||
}
|
||||
case "u2":
|
||||
if cnt == 1 {
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else if cnt == 2 {
|
||||
// After removing, there will be a "removed from channel post"
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else if cnt == 3 {
|
||||
assert.Equal(t, ch2.Id, post.ChannelId)
|
||||
} else {
|
||||
assert.Fail(t, "more than 3 messages arrived for u2")
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-quitChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}(obj.testName)
|
||||
}
|
||||
|
||||
// Will send to basic channel
|
||||
th.CreatePost()
|
||||
// Add u1 to ch2
|
||||
th.AddUserToChannel(u1, ch2)
|
||||
// Send post to ch2
|
||||
th.CreatePostWithClient(th.Client, ch2)
|
||||
// Remove u1 from ch2
|
||||
th.RemoveUserFromChannel(u1, ch2)
|
||||
// Send post to ch2
|
||||
th.CreatePostWithClient(th.Client, ch2)
|
||||
|
||||
// It is possible to create a signalling mechanism from the goroutines
|
||||
// after all events are received, but we also want to verify that no additional
|
||||
// events are being sent.
|
||||
time.Sleep(2 * time.Second)
|
||||
close(quitChan)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWebHubCloseConnOnDBFail(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer func() {
|
||||
th.TearDown()
|
||||
// Asserting that the error message is present in the log
|
||||
testlib.AssertLog(t, th.LogBuffer, mlog.LvlError.Name, "Error while registering to hub")
|
||||
_, err := th.Server.Store().GetInternalMasterDB().Exec(`ALTER TABLE dummy RENAME to ChannelMembers`)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
cli := th.CreateClient()
|
||||
_, _, err := cli.Login(context.Background(), th.BasicUser.Username, th.BasicUser.Password)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = th.Server.Store().GetInternalMasterDB().Exec(`ALTER TABLE ChannelMembers RENAME to dummy`)
|
||||
require.NoError(t, err)
|
||||
|
||||
wsClient, err := th.CreateWebSocketClientWithClient(cli)
|
||||
require.NoError(t, err)
|
||||
defer wsClient.Close()
|
||||
|
||||
require.NoError(t, th.TestLogger.Flush())
|
||||
}
|
||||
|
||||
func TestDeletePostEvent(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
@ -70,7 +70,7 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
cfg, err = c.App.Srv().Platform().PopulateWebConnConfig(c.AppContext.Session(), cfg, r.URL.Query().Get(sequenceNumberParam))
|
||||
if err != nil {
|
||||
c.Logger.Warn("Error while populating webconn config", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
|
||||
c.Logger.Error("Error while populating webconn config", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
|
||||
ws.Close()
|
||||
return
|
||||
}
|
||||
@ -78,7 +78,12 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
wc := c.App.Srv().Platform().NewWebConn(cfg, c.App, c.App.Srv().Channels())
|
||||
if c.AppContext.Session().UserId != "" {
|
||||
c.App.Srv().Platform().HubRegister(wc)
|
||||
err = c.App.Srv().Platform().HubRegister(wc)
|
||||
if err != nil {
|
||||
c.Logger.Error("Error while registering to hub", mlog.String("id", r.URL.Query().Get(connectionIDParam)), mlog.Err(err))
|
||||
ws.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
wc.Pump()
|
||||
|
@ -249,10 +249,6 @@ type AppIface interface {
|
||||
GetUserStatusesByIds(userIDs []string) ([]*model.Status, *model.AppError)
|
||||
// HasRemote returns whether a given channelID is present in the channel remotes or not.
|
||||
HasRemote(channelID string, remoteID string) (bool, error)
|
||||
// HubRegister registers a connection to a hub.
|
||||
HubRegister(webConn *platform.WebConn)
|
||||
// HubUnregister unregisters a connection from a hub.
|
||||
HubUnregister(webConn *platform.WebConn)
|
||||
// InstallPlugin unpacks and installs a plugin but does not enable or activate it unless the the
|
||||
// plugin was already enabled.
|
||||
InstallPlugin(pluginFile io.ReadSeeker, replace bool) (*model.Manifest, *model.AppError)
|
||||
|
@ -523,7 +523,7 @@ func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectio
|
||||
Locale: "en",
|
||||
ConnectionID: connectionID,
|
||||
}, th.App, th.App.Channels())
|
||||
th.App.Srv().Platform().HubRegister(webConn)
|
||||
require.NoError(t, th.App.Srv().Platform().HubRegister(webConn))
|
||||
|
||||
// Start reading from it
|
||||
go webConn.Pump()
|
||||
|
@ -11982,36 +11982,6 @@ func (a *OpenTracingAppLayer) HasSharedChannel(channelID string) (bool, error) {
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) HubRegister(webConn *platform.WebConn) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HubRegister")
|
||||
|
||||
a.ctx = newCtx
|
||||
a.app.Srv().Store().SetContext(newCtx)
|
||||
defer func() {
|
||||
a.app.Srv().Store().SetContext(origCtx)
|
||||
a.ctx = origCtx
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
a.app.HubRegister(webConn)
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) HubUnregister(webConn *platform.WebConn) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.HubUnregister")
|
||||
|
||||
a.ctx = newCtx
|
||||
a.app.Srv().Store().SetContext(newCtx)
|
||||
defer func() {
|
||||
a.app.Srv().Store().SetContext(origCtx)
|
||||
a.ctx = origCtx
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
a.app.HubUnregister(webConn)
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) IPFiltering() einterfaces.IPFilteringInterface {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.IPFiltering")
|
||||
|
@ -26,7 +26,6 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -95,10 +94,8 @@ type WebConn struct {
|
||||
UserId string
|
||||
PostedAck bool
|
||||
|
||||
allChannelMembers map[string]string
|
||||
lastAllChannelMembersTime int64
|
||||
lastUserActivityAt int64
|
||||
send chan model.WebSocketMessage
|
||||
lastUserActivityAt int64
|
||||
send chan model.WebSocketMessage
|
||||
// deadQueue behaves like a queue of a finite size
|
||||
// which is used to store all messages that are sent via the websocket.
|
||||
// It basically acts as the user-space socket buffer, and is used
|
||||
@ -218,7 +215,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
|
||||
if tcpConn != nil {
|
||||
err := tcpConn.SetNoDelay(false)
|
||||
if err != nil {
|
||||
mlog.Warn("Error in setting NoDelay socket opts", mlog.Err(err))
|
||||
ps.logger.Warn("Error in setting NoDelay socket opts", mlog.Err(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -321,6 +318,9 @@ func (wc *WebConn) SetConnectionID(id string) {
|
||||
|
||||
// GetConnectionID returns the connection id of the connection.
|
||||
func (wc *WebConn) GetConnectionID() string {
|
||||
if wc.connectionID.Load() == nil {
|
||||
return ""
|
||||
}
|
||||
return wc.connectionID.Load().(string)
|
||||
}
|
||||
|
||||
@ -566,7 +566,7 @@ func (wc *WebConn) writePump() {
|
||||
err = enc.Encode(msg)
|
||||
}
|
||||
if err != nil {
|
||||
mlog.Warn("Error in encoding websocket message", mlog.Err(err))
|
||||
wc.Platform.logger.Warn("Error in encoding websocket message", mlog.Err(err))
|
||||
continue
|
||||
}
|
||||
|
||||
@ -581,7 +581,7 @@ func (wc *WebConn) writePump() {
|
||||
logData = append(logData, mlog.String("channel_id", evt.GetBroadcast().ChannelId))
|
||||
}
|
||||
|
||||
mlog.Warn("websocket.full", logData...)
|
||||
wc.Platform.logger.Warn("websocket.full", logData...)
|
||||
wc.lastLogTimeFull = time.Now()
|
||||
}
|
||||
|
||||
@ -608,7 +608,7 @@ func (wc *WebConn) writePump() {
|
||||
|
||||
case <-authTicker.C:
|
||||
if wc.GetSessionToken() == "" {
|
||||
mlog.Debug("websocket.authTicker: did not authenticate", mlog.Stringer("ip_address", wc.WebSocket.RemoteAddr()))
|
||||
wc.Platform.logger.Debug("websocket.authTicker: did not authenticate", mlog.Stringer("ip_address", wc.WebSocket.RemoteAddr()))
|
||||
return
|
||||
}
|
||||
authTicker.Stop()
|
||||
@ -629,7 +629,7 @@ func (wc *WebConn) writeMessage(msg *model.WebSocketEvent) error {
|
||||
var buf bytes.Buffer
|
||||
err := msg.Encode(json.NewEncoder(&buf), &buf)
|
||||
if err != nil {
|
||||
mlog.Warn("Error in encoding websocket message", mlog.Err(err))
|
||||
wc.Platform.logger.Warn("Error in encoding websocket message", mlog.Err(err))
|
||||
return nil
|
||||
}
|
||||
wc.Sequence++
|
||||
@ -734,8 +734,6 @@ func (wc *WebConn) drainDeadQueue(index int) error {
|
||||
|
||||
// InvalidateCache resets all internal data of the WebConn.
|
||||
func (wc *WebConn) InvalidateCache() {
|
||||
wc.allChannelMembers = nil
|
||||
wc.lastAllChannelMembersTime = 0
|
||||
wc.SetSession(nil)
|
||||
wc.SetSessionExpiresAt(0)
|
||||
}
|
||||
@ -751,9 +749,9 @@ func (wc *WebConn) IsAuthenticated() bool {
|
||||
session, err := wc.Suite.GetSession(wc.GetSessionToken())
|
||||
if err != nil {
|
||||
if err.StatusCode >= http.StatusBadRequest && err.StatusCode < http.StatusInternalServerError {
|
||||
mlog.Debug("Invalid session.", mlog.Err(err))
|
||||
wc.Platform.logger.Debug("Invalid session.", mlog.Err(err))
|
||||
} else {
|
||||
mlog.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
|
||||
wc.Platform.logger.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
|
||||
}
|
||||
|
||||
wc.SetSessionToken("")
|
||||
@ -789,7 +787,7 @@ func (wc *WebConn) ShouldSendEventToGuest(msg *model.WebSocketEvent) bool {
|
||||
case model.WebsocketEventUserUpdated:
|
||||
user, ok := msg.GetData()["user"].(*model.User)
|
||||
if !ok {
|
||||
mlog.Debug("webhub.shouldSendEvent: user not found in message", mlog.Any("user", msg.GetData()["user"]))
|
||||
wc.Platform.logger.Debug("webhub.shouldSendEvent: user not found in message", mlog.Any("user", msg.GetData()["user"]))
|
||||
return false
|
||||
}
|
||||
userID = user.Id
|
||||
@ -828,7 +826,7 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
|
||||
model.WebsocketEventStatusChange,
|
||||
model.WebsocketEventMultipleChannelsViewed:
|
||||
if wc.Active.Load() && time.Since(wc.lastLogTimeSlow) > websocketSuppressWarnThreshold {
|
||||
mlog.Warn(
|
||||
wc.Platform.logger.Warn(
|
||||
"websocket.slow: dropping message",
|
||||
mlog.String("user_id", wc.UserId),
|
||||
mlog.String("conn_id", wc.GetConnectionID()),
|
||||
@ -893,8 +891,8 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
|
||||
|
||||
// Only report events to users who are in the channel for the event
|
||||
if chID := msg.GetBroadcast().ChannelId; chID != "" {
|
||||
// For typing events, we don't send them to users who don't have
|
||||
// that channel or thread opened.
|
||||
// For typing/reaction_added/reaction_removed events, we don't send them to users
|
||||
// who don't have that channel or thread opened.
|
||||
if wc.Platform.Config().FeatureFlags.WebSocketEventScope &&
|
||||
slices.Contains([]model.WebsocketEventType{
|
||||
model.WebsocketEventTyping,
|
||||
@ -904,30 +902,9 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if model.GetMillis()-wc.lastAllChannelMembersTime > webConnMemberCacheTime {
|
||||
wc.allChannelMembers = nil
|
||||
wc.lastAllChannelMembersTime = 0
|
||||
}
|
||||
|
||||
if wc.allChannelMembers == nil {
|
||||
result, err := wc.Platform.Store.Channel().GetAllChannelMembersForUser(
|
||||
sqlstore.RequestContextWithMaster(request.EmptyContext(wc.Platform.logger)),
|
||||
wc.UserId,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
mlog.Error("webhub.shouldSendEvent.", mlog.Err(err))
|
||||
return false
|
||||
}
|
||||
wc.allChannelMembers = result
|
||||
wc.lastAllChannelMembersTime = model.GetMillis()
|
||||
}
|
||||
|
||||
if _, ok := wc.allChannelMembers[chID]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
// We don't need to do any further checks because this is already scoped
|
||||
// to channel members from web_hub.
|
||||
return true
|
||||
}
|
||||
|
||||
// Only report events to users who are in the team for the event
|
||||
@ -960,9 +937,9 @@ func (wc *WebConn) isMemberOfTeam(teamID string) bool {
|
||||
session, err := wc.Suite.GetSession(wc.GetSessionToken())
|
||||
if err != nil {
|
||||
if err.StatusCode >= http.StatusBadRequest && err.StatusCode < http.StatusInternalServerError {
|
||||
mlog.Debug("Invalid session.", mlog.Err(err))
|
||||
wc.Platform.logger.Debug("Invalid session.", mlog.Err(err))
|
||||
} else {
|
||||
mlog.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
|
||||
wc.Platform.logger.Error("Could not get session", mlog.String("session_token", wc.GetSessionToken()), mlog.Err(err))
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -976,12 +953,12 @@ func (wc *WebConn) isMemberOfTeam(teamID string) bool {
|
||||
func (wc *WebConn) logSocketErr(source string, err error) {
|
||||
// browsers will appear as CloseNoStatusReceived
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
|
||||
mlog.Debug(source+": client side closed socket",
|
||||
wc.Platform.logger.Debug(source+": client side closed socket",
|
||||
mlog.String("user_id", wc.UserId),
|
||||
mlog.String("conn_id", wc.GetConnectionID()),
|
||||
mlog.String("origin_client", wc.originClient))
|
||||
} else {
|
||||
mlog.Debug(source+": closing websocket",
|
||||
wc.Platform.logger.Debug(source+": closing websocket",
|
||||
mlog.String("user_id", wc.UserId),
|
||||
mlog.String("conn_id", wc.GetConnectionID()),
|
||||
mlog.String("origin_client", wc.originClient),
|
||||
|
@ -4,6 +4,7 @@
|
||||
package platform
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
@ -14,6 +15,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -45,6 +47,11 @@ type webConnSessionMessage struct {
|
||||
isRegistered chan bool
|
||||
}
|
||||
|
||||
type webConnRegisterMessage struct {
|
||||
conn *WebConn
|
||||
err chan error
|
||||
}
|
||||
|
||||
type webConnCheckMessage struct {
|
||||
userID string
|
||||
connectionID string
|
||||
@ -65,7 +72,7 @@ type Hub struct {
|
||||
connectionCount int64
|
||||
platform *PlatformService
|
||||
connectionIndex int
|
||||
register chan *WebConn
|
||||
register chan *webConnRegisterMessage
|
||||
unregister chan *WebConn
|
||||
broadcast chan *model.WebSocketEvent
|
||||
stop chan struct{}
|
||||
@ -84,7 +91,7 @@ type Hub struct {
|
||||
func newWebHub(ps *PlatformService) *Hub {
|
||||
return &Hub{
|
||||
platform: ps,
|
||||
register: make(chan *WebConn),
|
||||
register: make(chan *webConnRegisterMessage),
|
||||
unregister: make(chan *WebConn),
|
||||
broadcast: make(chan *model.WebSocketEvent, broadcastQueueSize),
|
||||
stop: make(chan struct{}),
|
||||
@ -150,14 +157,15 @@ func (ps *PlatformService) GetHubForUserId(userID string) *Hub {
|
||||
}
|
||||
|
||||
// HubRegister registers a connection to a hub.
|
||||
func (ps *PlatformService) HubRegister(webConn *WebConn) {
|
||||
func (ps *PlatformService) HubRegister(webConn *WebConn) error {
|
||||
hub := ps.GetHubForUserId(webConn.UserId)
|
||||
if hub != nil {
|
||||
if metrics := ps.metricsIFace; metrics != nil {
|
||||
metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
|
||||
}
|
||||
hub.Register(webConn)
|
||||
return hub.Register(webConn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HubUnregister unregisters a connection from a hub.
|
||||
@ -262,11 +270,17 @@ func (ps *PlatformService) WebConnCountForUser(userID string) int {
|
||||
}
|
||||
|
||||
// Register registers a connection to the hub.
|
||||
func (h *Hub) Register(webConn *WebConn) {
|
||||
func (h *Hub) Register(webConn *WebConn) error {
|
||||
wr := &webConnRegisterMessage{
|
||||
conn: webConn,
|
||||
err: make(chan error),
|
||||
}
|
||||
select {
|
||||
case h.register <- webConn:
|
||||
case h.register <- wr:
|
||||
return <-wr.err
|
||||
case <-h.stop:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister unregisters a connection from the hub.
|
||||
@ -389,7 +403,10 @@ func (h *Hub) Start() {
|
||||
ticker := time.NewTicker(inactiveConnReaperInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
connIndex := newHubConnectionIndex(inactiveConnReaperInterval)
|
||||
connIndex := newHubConnectionIndex(inactiveConnReaperInterval,
|
||||
h.platform.Store,
|
||||
h.platform.logger,
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -423,22 +440,27 @@ func (h *Hub) Start() {
|
||||
req.result <- connIndex.ForUserActiveCount(req.userID)
|
||||
case <-ticker.C:
|
||||
connIndex.RemoveInactiveConnections()
|
||||
case webConn := <-h.register:
|
||||
case webConnReg := <-h.register:
|
||||
// Mark the current one as active.
|
||||
// There is no need to check if it was inactive or not,
|
||||
// we will anyways need to make it active.
|
||||
webConn.Active.Store(true)
|
||||
webConnReg.conn.Active.Store(true)
|
||||
|
||||
connIndex.Add(webConn)
|
||||
err := connIndex.Add(webConnReg.conn)
|
||||
if err != nil {
|
||||
webConnReg.err <- err
|
||||
continue
|
||||
}
|
||||
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
|
||||
|
||||
if webConn.IsAuthenticated() && webConn.reuseCount == 0 {
|
||||
if webConnReg.conn.IsAuthenticated() && webConnReg.conn.reuseCount == 0 {
|
||||
// The hello message should only be sent when the reuseCount is 0.
|
||||
// i.e in server restart, or long timeout, or fresh connection case.
|
||||
// In case of seq number not found in dead queue, it is handled by
|
||||
// the webconn write pump.
|
||||
webConn.send <- webConn.createHelloMessage()
|
||||
webConnReg.conn.send <- webConnReg.conn.createHelloMessage()
|
||||
}
|
||||
webConnReg.err <- nil
|
||||
case webConn := <-h.unregister:
|
||||
// If already removed (via queue full), then removing again becomes a noop.
|
||||
// But if not removed, mark inactive.
|
||||
@ -497,6 +519,13 @@ func (h *Hub) Start() {
|
||||
for _, webConn := range connIndex.ForUser(userID) {
|
||||
webConn.InvalidateCache()
|
||||
}
|
||||
err := connIndex.InvalidateCMCacheForUser(userID)
|
||||
if err != nil {
|
||||
h.platform.Log().Error("Error while invalidating channel member cache", mlog.String("user_id", userID), mlog.Err(err))
|
||||
for _, webConn := range connIndex.ForUser(userID) {
|
||||
closeAndRemoveConn(connIndex, webConn)
|
||||
}
|
||||
}
|
||||
case activity := <-h.activity:
|
||||
for _, webConn := range connIndex.ForUser(activity.userID) {
|
||||
if !webConn.Active.Load() {
|
||||
@ -519,8 +548,7 @@ func (h *Hub) Start() {
|
||||
mlog.String("user_id", directMsg.conn.UserId),
|
||||
mlog.String("conn_id", directMsg.conn.GetConnectionID()))
|
||||
}
|
||||
close(directMsg.conn.send)
|
||||
connIndex.Remove(directMsg.conn)
|
||||
closeAndRemoveConn(connIndex, directMsg.conn)
|
||||
}
|
||||
case msg := <-h.broadcast:
|
||||
if metrics := h.platform.metricsIFace; metrics != nil {
|
||||
@ -546,27 +574,29 @@ func (h *Hub) Start() {
|
||||
mlog.String("user_id", webConn.UserId),
|
||||
mlog.String("conn_id", webConn.GetConnectionID()))
|
||||
}
|
||||
close(webConn.send)
|
||||
connIndex.Remove(webConn)
|
||||
closeAndRemoveConn(connIndex, webConn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var targetConns []*WebConn
|
||||
if connID := msg.GetBroadcast().ConnectionId; connID != "" {
|
||||
if webConn := connIndex.ForConnection(connID); webConn != nil {
|
||||
broadcast(webConn)
|
||||
continue
|
||||
targetConns = append(targetConns, webConn)
|
||||
}
|
||||
} else if msg.GetBroadcast().UserId != "" {
|
||||
candidates := connIndex.ForUser(msg.GetBroadcast().UserId)
|
||||
for _, webConn := range candidates {
|
||||
} else if userID := msg.GetBroadcast().UserId; userID != "" {
|
||||
targetConns = connIndex.ForUser(userID)
|
||||
} else if channelID := msg.GetBroadcast().ChannelId; channelID != "" {
|
||||
targetConns = connIndex.ForChannel(channelID)
|
||||
}
|
||||
if targetConns != nil {
|
||||
for _, webConn := range targetConns {
|
||||
broadcast(webConn)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
candidates := connIndex.All()
|
||||
for webConn := range candidates {
|
||||
for webConn := range connIndex.All() {
|
||||
broadcast(webConn)
|
||||
}
|
||||
case <-h.stop:
|
||||
@ -616,14 +646,24 @@ func areAllInactive(conns []*WebConn) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// closeAndRemoveConn closes the send channel which will close the
|
||||
// websocket connection, and then it removes the webConn from the conn index.
|
||||
func closeAndRemoveConn(connIndex *hubConnectionIndex, conn *WebConn) {
|
||||
close(conn.send)
|
||||
connIndex.Remove(conn)
|
||||
}
|
||||
|
||||
// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
|
||||
// It requires 3 functionalities which need to be very fast:
|
||||
// It requires 4 functionalities which need to be very fast:
|
||||
// - check if a connection exists or not.
|
||||
// - get all connections for a given userID.
|
||||
// - get all connections for a given channelID.
|
||||
// - get all connections.
|
||||
type hubConnectionIndex struct {
|
||||
// byUserId stores the list of connections for a given userID
|
||||
byUserId map[string][]*WebConn
|
||||
// byChannelID stores the list of connections for a given channelID.
|
||||
byChannelID map[string][]*WebConn
|
||||
// byConnection serves the dual purpose of storing the index of the webconn
|
||||
// in the value of byUserId map, and also to get all connections.
|
||||
byConnection map[*WebConn]int
|
||||
@ -631,21 +671,39 @@ type hubConnectionIndex struct {
|
||||
// staleThreshold is the limit beyond which inactive connections
|
||||
// will be deleted.
|
||||
staleThreshold time.Duration
|
||||
|
||||
store store.Store
|
||||
logger mlog.LoggerIFace
|
||||
}
|
||||
|
||||
func newHubConnectionIndex(interval time.Duration) *hubConnectionIndex {
|
||||
func newHubConnectionIndex(interval time.Duration,
|
||||
store store.Store,
|
||||
logger mlog.LoggerIFace,
|
||||
) *hubConnectionIndex {
|
||||
return &hubConnectionIndex{
|
||||
byUserId: make(map[string][]*WebConn),
|
||||
byChannelID: make(map[string][]*WebConn),
|
||||
byConnection: make(map[*WebConn]int),
|
||||
byConnectionId: make(map[string]*WebConn),
|
||||
staleThreshold: interval,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *hubConnectionIndex) Add(wc *WebConn) {
|
||||
func (i *hubConnectionIndex) Add(wc *WebConn) error {
|
||||
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), wc.UserId, false, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getChannelMembersForUser: %v", err)
|
||||
}
|
||||
for chID := range cm {
|
||||
i.byChannelID[chID] = append(i.byChannelID[chID], wc)
|
||||
}
|
||||
|
||||
i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc)
|
||||
i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1
|
||||
i.byConnectionId[wc.GetConnectionID()] = wc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *hubConnectionIndex) Remove(wc *WebConn) {
|
||||
@ -654,21 +712,67 @@ func (i *hubConnectionIndex) Remove(wc *WebConn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the wc from i.byUserId
|
||||
// get the conn slice.
|
||||
userConnections := i.byUserId[wc.UserId]
|
||||
// get the last connection.
|
||||
last := userConnections[len(userConnections)-1]
|
||||
// set the slot that we are trying to remove to be the last connection.
|
||||
// https://go.dev/wiki/SliceTricks#delete-without-preserving-order
|
||||
userConnections[userConnIndex] = last
|
||||
// remove the last connection pointer from slice.
|
||||
userConnections[len(userConnections)-1] = nil
|
||||
// remove the last connection from the slice.
|
||||
i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1]
|
||||
// set the index of the connection that was moved to the new index.
|
||||
i.byConnection[last] = userConnIndex
|
||||
|
||||
connectionID := wc.GetConnectionID()
|
||||
// Remove webconns from i.byChannelID
|
||||
// This has O(n) complexity. We are trading off speed while removing
|
||||
// a connection, to improve broadcasting a message.
|
||||
for chID, webConns := range i.byChannelID {
|
||||
// https://go.dev/wiki/SliceTricks#filtering-without-allocating
|
||||
filtered := webConns[:0]
|
||||
for _, conn := range webConns {
|
||||
if conn.GetConnectionID() != connectionID {
|
||||
filtered = append(filtered, conn)
|
||||
}
|
||||
}
|
||||
for i := len(filtered); i < len(webConns); i++ {
|
||||
webConns[i] = nil
|
||||
}
|
||||
i.byChannelID[chID] = filtered
|
||||
}
|
||||
|
||||
delete(i.byConnection, wc)
|
||||
delete(i.byConnectionId, wc.GetConnectionID())
|
||||
delete(i.byConnectionId, connectionID)
|
||||
}
|
||||
|
||||
func (i *hubConnectionIndex) InvalidateCMCacheForUser(userID string) error {
|
||||
// We make this query first to fail fast in case of an error.
|
||||
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), userID, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear out all user entries which belong to channels.
|
||||
for chID, webConns := range i.byChannelID {
|
||||
// https://go.dev/wiki/SliceTricks#filtering-without-allocating
|
||||
filtered := webConns[:0]
|
||||
for _, conn := range webConns {
|
||||
if conn.UserId != userID {
|
||||
filtered = append(filtered, conn)
|
||||
}
|
||||
}
|
||||
for i := len(filtered); i < len(webConns); i++ {
|
||||
webConns[i] = nil
|
||||
}
|
||||
i.byChannelID[chID] = filtered
|
||||
}
|
||||
|
||||
// re-populate the cache
|
||||
for chID := range cm {
|
||||
i.byChannelID[chID] = append(i.byChannelID[chID], i.ForUser(userID)...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *hubConnectionIndex) Has(wc *WebConn) bool {
|
||||
@ -691,6 +795,16 @@ func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
|
||||
return conns
|
||||
}
|
||||
|
||||
// ForChannel returns all connections for a channelID.
|
||||
func (i *hubConnectionIndex) ForChannel(channelID string) []*WebConn {
|
||||
// Note: this is expensive because usually there will be
|
||||
// more than 1 member for a channel, and broadcasting
|
||||
// is a hot path, but worth it.
|
||||
conns := make([]*WebConn, len(i.byChannelID[channelID]))
|
||||
copy(conns, i.byChannelID[channelID])
|
||||
return conns
|
||||
}
|
||||
|
||||
// ForUserActiveCount returns the number of active connections for a userID
|
||||
func (i *hubConnectionIndex) ForUserActiveCount(id string) int {
|
||||
cnt := 0
|
||||
|
@ -15,13 +15,11 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
||||
platform_mocks "github.com/mattermost/mattermost/server/v8/channels/app/platform/mocks"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/testlib"
|
||||
)
|
||||
|
||||
@ -53,7 +51,7 @@ func registerDummyWebConn(t *testing.T, th *TestHelper, addr net.Addr, session *
|
||||
Locale: "en",
|
||||
}
|
||||
wc := th.Service.NewWebConn(cfg, th.Suite, &hookRunner{})
|
||||
th.Service.HubRegister(wc)
|
||||
require.NoError(t, th.Service.HubRegister(wc))
|
||||
go wc.Pump()
|
||||
return wc
|
||||
}
|
||||
@ -105,8 +103,8 @@ func TestHubStopRaceCondition(t *testing.T) {
|
||||
go func() {
|
||||
wc4 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
wc5 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
hub.Register(wc4)
|
||||
hub.Register(wc5)
|
||||
require.NoError(t, hub.Register(wc4))
|
||||
require.NoError(t, hub.Register(wc5))
|
||||
|
||||
hub.UpdateActivity("userId", "sessionToken", 0)
|
||||
|
||||
@ -128,50 +126,9 @@ func TestHubStopRaceCondition(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHubSessionRevokeRace(t *testing.T) {
|
||||
th := SetupWithStoreMock(t)
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
sess1 := &model.Session{
|
||||
Id: "id1",
|
||||
UserId: "user1",
|
||||
DeviceId: "",
|
||||
Token: "sesstoken",
|
||||
ExpiresAt: model.GetMillis() + 300000,
|
||||
LastActivityAt: 10000,
|
||||
}
|
||||
|
||||
mockStore := th.Service.Store.(*mocks.Store)
|
||||
|
||||
mockUserStore := mocks.UserStore{}
|
||||
mockUserStore.On("Count", mock.Anything).Return(int64(10), nil)
|
||||
mockUserStore.On("GetUnreadCount", mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(int64(1), nil)
|
||||
mockPostStore := mocks.PostStore{}
|
||||
mockPostStore.On("GetMaxPostSize").Return(65535, nil)
|
||||
mockSystemStore := mocks.SystemStore{}
|
||||
mockSystemStore.On("GetByName", "UpgradedFromTE").Return(&model.System{Name: "UpgradedFromTE", Value: "false"}, nil)
|
||||
mockSystemStore.On("GetByName", "InstallationDate").Return(&model.System{Name: "InstallationDate", Value: "10"}, nil)
|
||||
mockSystemStore.On("GetByName", "FirstServerRunTimestamp").Return(&model.System{Name: "FirstServerRunTimestamp", Value: "10"}, nil)
|
||||
|
||||
mockSessionStore := mocks.SessionStore{}
|
||||
mockSessionStore.On("UpdateLastActivityAt", "id1", mock.Anything).Return(nil)
|
||||
mockSessionStore.On("Save", mock.AnythingOfType("*request.Context"), mock.AnythingOfType("*model.Session")).Return(sess1, nil)
|
||||
mockSessionStore.On("Get", mock.AnythingOfType("*request.Context"), mock.Anything, "id1").Return(sess1, nil)
|
||||
mockSessionStore.On("Remove", "id1").Return(nil)
|
||||
|
||||
mockStatusStore := mocks.StatusStore{}
|
||||
mockStatusStore.On("Get", "user1").Return(&model.Status{UserId: "user1", Status: model.StatusOnline}, nil)
|
||||
mockStatusStore.On("UpdateLastActivityAt", "user1", mock.Anything).Return(nil)
|
||||
mockStatusStore.On("SaveOrUpdate", mock.AnythingOfType("*model.Status")).Return(nil)
|
||||
|
||||
mockOAuthStore := mocks.OAuthStore{}
|
||||
mockStore.On("Session").Return(&mockSessionStore)
|
||||
mockStore.On("OAuth").Return(&mockOAuthStore)
|
||||
mockStore.On("Status").Return(&mockStatusStore)
|
||||
mockStore.On("User").Return(&mockUserStore)
|
||||
mockStore.On("Post").Return(&mockPostStore)
|
||||
mockStore.On("System").Return(&mockSystemStore)
|
||||
mockStore.On("GetDBSchemaVersion").Return(1, nil)
|
||||
|
||||
// This needs to be false for the condition to trigger
|
||||
th.Service.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.ExtendSessionLengthWithActivity = false
|
||||
@ -181,7 +138,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
session, err := th.Service.CreateSession(th.Context, &model.Session{
|
||||
UserId: "testid",
|
||||
UserId: model.NewId(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -193,7 +150,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
|
||||
time.Sleep(2 * time.Second)
|
||||
// We override the LastActivityAt which happens in NewWebConn.
|
||||
// This is needed to call RevokeSessionById which triggers the race.
|
||||
th.Service.AddSessionToCache(sess1)
|
||||
th.Service.AddSessionToCache(session)
|
||||
|
||||
go func() {
|
||||
for i := 0; i <= broadcastQueueSize; i++ {
|
||||
@ -203,7 +160,7 @@ func TestHubSessionRevokeRace(t *testing.T) {
|
||||
}()
|
||||
|
||||
// This call should happen _after_ !wc.IsAuthenticated() and _before_wc.isMemberOfTeam().
|
||||
// There's no guarantee this will happen. But that's out best bet to trigger this race.
|
||||
// There's no guarantee this will happen. But that's our best bet to trigger this race.
|
||||
wc1.InvalidateCache()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
@ -220,92 +177,236 @@ func TestHubSessionRevokeRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHubConnIndex(t *testing.T) {
|
||||
th := Setup(t)
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
|
||||
// User1
|
||||
wc1 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc1.SetConnectionID(model.NewId())
|
||||
wc1.SetSession(&model.Session{})
|
||||
|
||||
// User2
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc2.SetConnectionID(model.NewId())
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID(model.NewId())
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
wc4 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc4.SetConnectionID(model.NewId())
|
||||
wc4.SetSession(&model.Session{})
|
||||
|
||||
connIndex.Add(wc1)
|
||||
connIndex.Add(wc2)
|
||||
connIndex.Add(wc3)
|
||||
connIndex.Add(wc4)
|
||||
_, err := th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
UserId: th.BasicUser.Id,
|
||||
NotifyProps: model.GetDefaultChannelNotifyProps(),
|
||||
SchemeGuest: th.BasicUser.IsGuest(),
|
||||
SchemeUser: !th.BasicUser.IsGuest(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
UserId: th.BasicUser2.Id,
|
||||
NotifyProps: model.GetDefaultChannelNotifyProps(),
|
||||
SchemeGuest: th.BasicUser2.IsGuest(),
|
||||
SchemeUser: !th.BasicUser2.IsGuest(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Basic", func(t *testing.T) {
|
||||
assert.True(t, connIndex.Has(wc1))
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc3, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.True(t, connIndex.Has(wc1))
|
||||
assert.Len(t, connIndex.All(), 4)
|
||||
// User1
|
||||
wc1 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc1.SetConnectionID(model.NewId())
|
||||
wc1.SetSession(&model.Session{})
|
||||
|
||||
// User2
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc2.SetConnectionID(model.NewId())
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID(model.NewId())
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
wc4 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc4.SetConnectionID(model.NewId())
|
||||
wc4.SetSession(&model.Session{})
|
||||
|
||||
connIndex.Add(wc1)
|
||||
connIndex.Add(wc2)
|
||||
connIndex.Add(wc3)
|
||||
connIndex.Add(wc4)
|
||||
|
||||
t.Run("Basic", func(t *testing.T) {
|
||||
assert.True(t, connIndex.Has(wc1))
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc3, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.True(t, connIndex.Has(wc1))
|
||||
assert.Len(t, connIndex.All(), 4)
|
||||
})
|
||||
|
||||
t.Run("RemoveMiddleUser2", func(t *testing.T) {
|
||||
connIndex.Remove(wc3) // Remove from middle from user2
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.True(t, connIndex.Has(wc4))
|
||||
assert.Len(t, connIndex.All(), 3)
|
||||
})
|
||||
|
||||
t.Run("RemoveUser1", func(t *testing.T) {
|
||||
connIndex.Remove(wc1) // Remove sole connection from user1
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
|
||||
assert.Len(t, connIndex.ForUser(wc1.UserId), 0)
|
||||
assert.Len(t, connIndex.All(), 2)
|
||||
assert.False(t, connIndex.Has(wc1))
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
})
|
||||
|
||||
t.Run("RemoveEndUser2", func(t *testing.T) {
|
||||
connIndex.Remove(wc4) // Remove from end from user2
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.False(t, connIndex.Has(wc4))
|
||||
assert.Len(t, connIndex.All(), 1)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("RemoveMiddleUser2", func(t *testing.T) {
|
||||
connIndex.Remove(wc3) // Remove from middle from user2
|
||||
t.Run("ByConnectionId", func(t *testing.T) {
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{wc1})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.True(t, connIndex.Has(wc4))
|
||||
assert.Len(t, connIndex.All(), 3)
|
||||
// User1
|
||||
wc1ID := model.NewId()
|
||||
wc1 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: th.BasicUser.Id,
|
||||
}
|
||||
wc1.SetConnectionID(wc1ID)
|
||||
wc1.SetSession(&model.Session{})
|
||||
|
||||
// User2
|
||||
wc2ID := model.NewId()
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: th.BasicUser2.Id,
|
||||
}
|
||||
wc2.SetConnectionID(wc2ID)
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
wc3ID := model.NewId()
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID(wc3ID)
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
t.Run("no connections", func(t *testing.T) {
|
||||
assert.False(t, connIndex.Has(wc1))
|
||||
assert.False(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.Empty(t, connIndex.byConnectionId)
|
||||
})
|
||||
|
||||
t.Run("adding", func(t *testing.T) {
|
||||
connIndex.Add(wc1)
|
||||
connIndex.Add(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 2)
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
|
||||
t.Run("removing", func(t *testing.T) {
|
||||
connIndex.Remove(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 1)
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("RemoveUser1", func(t *testing.T) {
|
||||
connIndex.Remove(wc1) // Remove sole connection from user1
|
||||
t.Run("ByChannelId", func(t *testing.T) {
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2, wc4})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
|
||||
assert.Len(t, connIndex.ForUser(wc1.UserId), 0)
|
||||
assert.Len(t, connIndex.All(), 2)
|
||||
assert.False(t, connIndex.Has(wc1))
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
})
|
||||
// User1
|
||||
wc1ID := model.NewId()
|
||||
wc1 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: th.BasicUser.Id,
|
||||
}
|
||||
wc1.SetConnectionID(wc1ID)
|
||||
wc1.SetSession(&model.Session{})
|
||||
|
||||
t.Run("RemoveEndUser2", func(t *testing.T) {
|
||||
connIndex.Remove(wc4) // Remove from end from user2
|
||||
// User2
|
||||
wc2ID := model.NewId()
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: th.BasicUser2.Id,
|
||||
}
|
||||
wc2.SetConnectionID(wc2ID)
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc2.UserId), []*WebConn{wc2})
|
||||
assert.ElementsMatch(t, connIndex.ForUser(wc1.UserId), []*WebConn{})
|
||||
assert.True(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.False(t, connIndex.Has(wc4))
|
||||
assert.Len(t, connIndex.All(), 1)
|
||||
wc3ID := model.NewId()
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID(wc3ID)
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
connIndex.Add(wc1)
|
||||
connIndex.Add(wc2)
|
||||
connIndex.Add(wc3)
|
||||
|
||||
t.Run("ForChannel", func(t *testing.T) {
|
||||
require.Len(t, connIndex.byChannelID, 1)
|
||||
require.Equal(t, []*WebConn{wc1, wc2, wc3}, connIndex.ForChannel(th.BasicChannel.Id))
|
||||
require.Len(t, connIndex.ForChannel("notexist"), 0)
|
||||
})
|
||||
|
||||
ch := th.CreateChannel(th.BasicTeam)
|
||||
_, err = th.Service.Store.Channel().SaveMember(th.Context, &model.ChannelMember{
|
||||
ChannelId: ch.Id,
|
||||
UserId: th.BasicUser2.Id,
|
||||
NotifyProps: model.GetDefaultChannelNotifyProps(),
|
||||
SchemeGuest: th.BasicUser2.IsGuest(),
|
||||
SchemeUser: !th.BasicUser2.IsGuest(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("InvalidateCMCacheForUser", func(t *testing.T) {
|
||||
require.NoError(t, connIndex.InvalidateCMCacheForUser(th.BasicUser2.Id))
|
||||
require.Len(t, connIndex.byChannelID, 2)
|
||||
require.Len(t, connIndex.ForChannel(th.BasicChannel.Id), 3)
|
||||
require.Len(t, connIndex.ForChannel(ch.Id), 2)
|
||||
})
|
||||
|
||||
t.Run("Remove", func(t *testing.T) {
|
||||
connIndex.Remove(wc3)
|
||||
require.Len(t, connIndex.byChannelID, 2)
|
||||
require.Len(t, connIndex.ForChannel(th.BasicChannel.Id), 2)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -313,7 +414,7 @@ func TestHubConnIndexIncorrectRemoval(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
// User2
|
||||
wc2 := &WebConn{
|
||||
@ -356,73 +457,11 @@ func TestHubConnIndexIncorrectRemoval(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubConnIndexByConnectionId(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
|
||||
// User1
|
||||
wc1ID := model.NewId()
|
||||
wc1 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc1.SetConnectionID(wc1ID)
|
||||
wc1.SetSession(&model.Session{})
|
||||
|
||||
// User2
|
||||
wc2ID := model.NewId()
|
||||
wc2 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: model.NewId(),
|
||||
}
|
||||
wc2.SetConnectionID(wc2ID)
|
||||
wc2.SetSession(&model.Session{})
|
||||
|
||||
wc3ID := model.NewId()
|
||||
wc3 := &WebConn{
|
||||
Platform: th.Service,
|
||||
Suite: th.Suite,
|
||||
UserId: wc2.UserId,
|
||||
}
|
||||
wc3.SetConnectionID(wc3ID)
|
||||
wc3.SetSession(&model.Session{})
|
||||
|
||||
t.Run("no connections", func(t *testing.T) {
|
||||
assert.False(t, connIndex.Has(wc1))
|
||||
assert.False(t, connIndex.Has(wc2))
|
||||
assert.False(t, connIndex.Has(wc3))
|
||||
assert.Empty(t, connIndex.byConnectionId)
|
||||
})
|
||||
|
||||
t.Run("adding", func(t *testing.T) {
|
||||
connIndex.Add(wc1)
|
||||
connIndex.Add(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 2)
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, wc3, connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
|
||||
t.Run("removing", func(t *testing.T) {
|
||||
connIndex.Remove(wc3)
|
||||
|
||||
assert.Len(t, connIndex.byConnectionId, 1)
|
||||
assert.Equal(t, wc1, connIndex.ForConnection(wc1ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc3ID))
|
||||
assert.Equal(t, (*WebConn)(nil), connIndex.ForConnection(wc2ID))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHubConnIndexInactive(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(2 * time.Second)
|
||||
connIndex := newHubConnectionIndex(2*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
// User1
|
||||
wc1 := &WebConn{
|
||||
@ -582,7 +621,7 @@ func TestHubWebConnCount(t *testing.T) {
|
||||
func BenchmarkHubConnIndex(b *testing.B) {
|
||||
th := Setup(b).InitBasic()
|
||||
defer th.TearDown()
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
// User1
|
||||
wc1 := &WebConn{
|
||||
@ -627,7 +666,7 @@ func TestHubConnIndexRemoveMemLeak(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
connIndex := newHubConnectionIndex(1 * time.Second)
|
||||
connIndex := newHubConnectionIndex(1*time.Second, th.Service.Store, th.Service.logger)
|
||||
|
||||
wc := &WebConn{
|
||||
Platform: th.Service,
|
||||
|
@ -49,6 +49,7 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque
|
||||
|
||||
session, err := conn.Suite.GetSession(token)
|
||||
if err != nil {
|
||||
conn.Platform.Log().Warn("Error while getting session token", mlog.Err(err))
|
||||
conn.WebSocket.Close()
|
||||
return
|
||||
}
|
||||
@ -56,7 +57,12 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque
|
||||
conn.SetSessionToken(session.Token)
|
||||
conn.UserId = session.UserId
|
||||
|
||||
conn.Platform.HubRegister(conn)
|
||||
nErr := conn.Platform.HubRegister(conn)
|
||||
if nErr != nil {
|
||||
conn.Platform.Log().Error("Error while registering to hub", mlog.String("user_id", conn.UserId), mlog.Err(nErr))
|
||||
conn.WebSocket.Close()
|
||||
return
|
||||
}
|
||||
|
||||
conn.Platform.Go(func() {
|
||||
conn.Platform.SetStatusOnline(session.UserId, false)
|
||||
|
@ -15,6 +15,9 @@ import (
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||
)
|
||||
|
||||
// TestWebConnShouldSendEvent is not exhaustive because some of the checks
|
||||
// happen inside web_hub.go before the event is actually broadcasted, and checked
|
||||
// via ShouldSendEvent.
|
||||
func TestWebConnShouldSendEvent(t *testing.T) {
|
||||
os.Setenv("MM_FEATUREFLAGS_WEBSOCKETEVENTSCOPE", "true")
|
||||
defer os.Unsetenv("MM_FEATUREFLAGS_WEBSOCKETEVENTSCOPE")
|
||||
@ -157,14 +160,6 @@ func TestWebConnShouldSendEvent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("should send to basic user in basic channel", func(t *testing.T) {
|
||||
event = event.SetBroadcast(&model.WebsocketBroadcast{ChannelId: th.BasicChannel.Id})
|
||||
|
||||
assert.True(t, basicUserWc.ShouldSendEvent(event), "expected user 1")
|
||||
assert.False(t, basicUser2Wc.ShouldSendEvent(event), "did not expect user 2")
|
||||
assert.False(t, adminUserWc.ShouldSendEvent(event), "did not expect admin")
|
||||
})
|
||||
|
||||
t.Run("should not send typing event unless in scope", func(t *testing.T) {
|
||||
event2 := model.NewWebSocketEvent(model.WebsocketEventTyping, "", th.BasicChannel.Id, "", nil, "")
|
||||
// Basic, unset case
|
||||
@ -222,14 +217,6 @@ func TestWebConnShouldSendEvent(t *testing.T) {
|
||||
assert.False(t, basicUserWc.ShouldSendEvent(event2))
|
||||
})
|
||||
|
||||
t.Run("should send to basic user and admin in channel2", func(t *testing.T) {
|
||||
event = event.SetBroadcast(&model.WebsocketBroadcast{ChannelId: channel2.Id})
|
||||
|
||||
assert.True(t, basicUserWc.ShouldSendEvent(event), "expected user 1")
|
||||
assert.False(t, basicUser2Wc.ShouldSendEvent(event), "did not expect user 2")
|
||||
assert.True(t, adminUserWc.ShouldSendEvent(event), "expected admin")
|
||||
})
|
||||
|
||||
t.Run("channel member cache invalidated after user added to channel", func(t *testing.T) {
|
||||
th.AddUserToChannel(th.BasicUser2, channel2)
|
||||
basicUser2Wc.InvalidateCache()
|
||||
|
@ -16,16 +16,6 @@ func (a *App) GetHubForUserId(userID string) *platform.Hub {
|
||||
return a.Srv().Platform().GetHubForUserId(userID)
|
||||
}
|
||||
|
||||
// HubRegister registers a connection to a hub.
|
||||
func (a *App) HubRegister(webConn *platform.WebConn) {
|
||||
a.Srv().Platform().HubRegister(webConn)
|
||||
}
|
||||
|
||||
// HubUnregister unregisters a connection from a hub.
|
||||
func (a *App) HubUnregister(webConn *platform.WebConn) {
|
||||
a.Srv().Platform().HubUnregister(webConn)
|
||||
}
|
||||
|
||||
func (a *App) Publish(message *model.WebSocketEvent) {
|
||||
a.Srv().Platform().Publish(message)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user