[MM-58355] Send invalidate cache message across the cluster so that websocket connections on other instances are invalidated correctly (#27204)

* [MM-58355] Send invalidate cache message across the cluster so that websocket connections on other instances are invalidated correctly

* Add suggestion to clear the session cache on the local node as well

* Force read from master DB when gettting channel members for websocket to avoid any DB sync issues

* PR feedback

* Missed generated files
This commit is contained in:
Devin Binnie 2024-06-07 09:38:53 -04:00 committed by GitHub
parent 91741a7fa4
commit f3e760008c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 55 additions and 57 deletions

View File

@ -82,7 +82,7 @@ func (a *App) SessionHasPermissionToChannel(c request.CTX, session model.Session
return false
}
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(session.UserId, true, true)
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(c, session.UserId, true, true)
var channelRoles []string
if err == nil {
if roles, ok := ids[channelID]; ok {
@ -134,7 +134,7 @@ func (a *App) SessionHasPermissionToChannels(c request.CTX, session model.Sessio
return true
}
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(session.UserId, true, true)
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(c, session.UserId, true, true)
var channelRoles []string
for _, channelID := range channelIDs {
if err == nil {
@ -266,7 +266,7 @@ func (a *App) HasPermissionToChannel(c request.CTX, askingUserId string, channel
// We call GetAllChannelMembersForUser instead of just getting
// a single member from the DB, because it's cache backed
// and this is a very frequent call.
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(askingUserId, true, true)
ids, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(c, askingUserId, true, true)
var channelRoles []string
if err == nil {
if roles, ok := ids[channelID]; ok {

View File

@ -152,7 +152,7 @@ func TestSessionHasPermissionToChannel(t *testing.T) {
mockChannelStore := mocks.ChannelStore{}
mockChannelStore.On("Get", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("arbitrary error"))
mockChannelStore.On("GetAllChannelMembersForUser", mock.Anything, mock.Anything, mock.Anything).Return(th.App.Srv().Store().Channel().GetAllChannelMembersForUser(th.BasicUser.Id, false, false))
mockChannelStore.On("GetAllChannelMembersForUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(th.App.Srv().Store().Channel().GetAllChannelMembersForUser(th.Context, th.BasicUser.Id, false, false))
mockChannelStore.On("ClearCaches").Return()
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("FileInfo").Return(th.App.Srv().Store().FileInfo())
@ -214,7 +214,7 @@ func TestSessionHasPermissionToChannels(t *testing.T) {
mockStore := mocks.Store{}
mockChannelStore := mocks.ChannelStore{}
mockChannelStore.On("Get", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("arbitrary error"))
mockChannelStore.On("GetAllChannelMembersForUser", mock.Anything, mock.Anything, mock.Anything).Return(th.App.Srv().Store().Channel().GetAllChannelMembersForUser(th.BasicUser.Id, false, false))
mockChannelStore.On("GetAllChannelMembersForUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(th.App.Srv().Store().Channel().GetAllChannelMembersForUser(th.Context, th.BasicUser.Id, false, false))
mockChannelStore.On("ClearCaches").Return()
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("FileInfo").Return(th.App.Srv().Store().FileInfo())

View File

@ -1597,19 +1597,11 @@ func (a *App) AddUserToChannel(c request.CTX, user *model.User, channel *model.C
return nil, err
}
// We are sending separate websocket events to the user added and to the channel
// This is to get around potential cluster syncing issues where other nodes may not receive the most up to date channel members
// There is likely some issue syncing these that needs to be looked at, but this is the current fix.
message := model.NewWebSocketEvent(model.WebsocketEventUserAdded, "", channel.Id, "", map[string]bool{user.Id: true}, "")
message := model.NewWebSocketEvent(model.WebsocketEventUserAdded, "", channel.Id, "", nil, "")
message.Add("user_id", user.Id)
message.Add("team_id", channel.TeamId)
a.Publish(message)
userMessage := model.NewWebSocketEvent(model.WebsocketEventUserAdded, "", channel.Id, user.Id, nil, "")
userMessage.Add("user_id", user.Id)
userMessage.Add("team_id", channel.TeamId)
a.Publish(userMessage)
return newMember, nil
}

View File

@ -70,7 +70,7 @@ func (ps *PlatformService) ClearAllUsersSessionCacheLocal() {
}
func (ps *PlatformService) ClearUserSessionCache(userID string) {
ps.ClearUserSessionCacheLocal(userID)
ps.ClearSessionCacheForUserSkipClusterSend(userID)
if ps.clusterIFace != nil {
msg := &model.ClusterMessage{

View File

@ -26,6 +26,7 @@ import (
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
)
const (
@ -895,7 +896,12 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
}
if wc.allChannelMembers == nil {
result, err := wc.Platform.Store.Channel().GetAllChannelMembersForUser(wc.UserId, false, false)
result, err := wc.Platform.Store.Channel().GetAllChannelMembersForUser(
sqlstore.RequestContextWithMaster(request.EmptyContext(wc.Platform.logger)),
wc.UserId,
false,
false,
)
if err != nil {
mlog.Error("webhub.shouldSendEvent.", mlog.Err(err))
return false

View File

@ -197,7 +197,7 @@ func (ps *PlatformService) InvalidateCacheForChannelPosts(channelID string) {
func (ps *PlatformService) InvalidateCacheForUser(userID string) {
ps.Store.Channel().InvalidateAllChannelMembersForUser(userID)
ps.invalidateWebConnSessionCacheForUser(userID)
ps.ClearUserSessionCache(userID)
ps.Store.User().InvalidateProfilesInChannelCacheByUser(userID)
ps.Store.User().InvalidateProfileCacheForUser(userID)

View File

@ -2388,7 +2388,7 @@ func (a *App) GetViewUsersRestrictions(c request.CTX, userID string) (*model.Vie
}
}
userChannelMembers, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(userID, true, true)
userChannelMembers, err := a.Srv().Store().Channel().GetAllChannelMembersForUser(c, userID, true, true)
if err != nil {
return nil, model.NewAppError("GetViewUsersRestrictions", "app.channel.get_channels.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
}

View File

@ -268,7 +268,7 @@ func (s LocalCacheChannelStore) GetMany(ids []string, allowFromCache bool) (mode
return append(foundChannels, channels...), nil
}
func (s LocalCacheChannelStore) GetAllChannelMembersForUser(userId string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
func (s LocalCacheChannelStore) GetAllChannelMembersForUser(ctx request.CTX, userId string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
cache_key := userId
if includeDeleted {
cache_key += "_deleted"
@ -280,7 +280,7 @@ func (s LocalCacheChannelStore) GetAllChannelMembersForUser(userId string, allow
}
}
ids, err := s.ChannelStore.GetAllChannelMembersForUser(userId, allowFromCache, includeDeleted)
ids, err := s.ChannelStore.GetAllChannelMembersForUser(ctx, userId, allowFromCache, includeDeleted)
if err != nil {
return nil, err
}

View File

@ -983,7 +983,7 @@ func (s *OpenTracingLayerChannelStore) GetAllChannelMemberIdsByChannelId(id stri
return result, err
}
func (s *OpenTracingLayerChannelStore) GetAllChannelMembersForUser(userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
func (s *OpenTracingLayerChannelStore) GetAllChannelMembersForUser(ctx request.CTX, userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GetAllChannelMembersForUser")
s.Root.Store.SetContext(newCtx)
@ -992,7 +992,7 @@ func (s *OpenTracingLayerChannelStore) GetAllChannelMembersForUser(userID string
}()
defer span.Finish()
result, err := s.ChannelStore.GetAllChannelMembersForUser(userID, allowFromCache, includeDeleted)
result, err := s.ChannelStore.GetAllChannelMembersForUser(ctx, userID, allowFromCache, includeDeleted)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
@ -11498,7 +11498,7 @@ func (s *OpenTracingLayerUserStore) GetAllProfiles(options *model.UserGetOptions
return result, err
}
func (s *OpenTracingLayerUserStore) GetAllProfilesInChannel(ctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
func (s *OpenTracingLayerUserStore) GetAllProfilesInChannel(rctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "UserStore.GetAllProfilesInChannel")
s.Root.Store.SetContext(newCtx)
@ -11507,7 +11507,7 @@ func (s *OpenTracingLayerUserStore) GetAllProfilesInChannel(ctx context.Context,
}()
defer span.Finish()
result, err := s.UserStore.GetAllProfilesInChannel(ctx, channelID, allowFromCache)
result, err := s.UserStore.GetAllProfilesInChannel(rctx, channelID, allowFromCache)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)

View File

@ -1070,11 +1070,11 @@ func (s *RetryLayerChannelStore) GetAllChannelMemberIdsByChannelId(id string) ([
}
func (s *RetryLayerChannelStore) GetAllChannelMembersForUser(userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
func (s *RetryLayerChannelStore) GetAllChannelMembersForUser(ctx request.CTX, userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
tries := 0
for {
result, err := s.ChannelStore.GetAllChannelMembersForUser(userID, allowFromCache, includeDeleted)
result, err := s.ChannelStore.GetAllChannelMembersForUser(ctx, userID, allowFromCache, includeDeleted)
if err == nil {
return result, nil
}
@ -13154,11 +13154,11 @@ func (s *RetryLayerUserStore) GetAllProfiles(options *model.UserGetOptions) ([]*
}
func (s *RetryLayerUserStore) GetAllProfilesInChannel(ctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
func (s *RetryLayerUserStore) GetAllProfilesInChannel(rctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
tries := 0
for {
result, err := s.UserStore.GetAllProfilesInChannel(ctx, channelID, allowFromCache)
result, err := s.UserStore.GetAllProfilesInChannel(rctx, channelID, allowFromCache)
if err == nil {
return result, nil
}

View File

@ -90,7 +90,7 @@ func (s *SearchStore) indexUser(rctx request.CTX, user *model.User) {
userTeamsIds = append(userTeamsIds, team.Id)
}
userChannelMembers, err := s.Channel().GetAllChannelMembersForUser(user.Id, false, true)
userChannelMembers, err := s.Channel().GetAllChannelMembersForUser(rctx, user.Id, false, true)
if err != nil {
rctx.Logger().Error("Encountered error indexing user", mlog.String("user_id", user.Id), mlog.String("search_engine", engineCopy.GetName()), mlog.Err(err))
return

View File

@ -2125,7 +2125,7 @@ func (s SqlChannelStore) GetMemberForPost(postId string, userId string, includeA
return dbMember.ToModel(), nil
}
func (s SqlChannelStore) GetAllChannelMembersForUser(userId string, allowFromCache bool, includeDeleted bool) (_ map[string]string, err error) {
func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId string, allowFromCache bool, includeDeleted bool) (_ map[string]string, err error) {
query := s.getQueryBuilder().
Select(`
ChannelMembers.ChannelId, ChannelMembers.Roles, ChannelMembers.SchemeGuest,
@ -2151,7 +2151,7 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(userId string, allowFromCac
return nil, errors.Wrap(err, "channel_tosql")
}
rows, err := s.GetReplicaX().DB.Query(queryString, args...)
rows, err := s.SqlStore.DBXFromContext(rctx.Context()).Query(queryString, args...)
if err != nil {
return nil, errors.Wrap(err, "failed to find ChannelMembers, TeamScheme and ChannelScheme data")
}

View File

@ -229,7 +229,7 @@ type ChannelStore interface {
GetMember(ctx context.Context, channelID string, userID string) (*model.ChannelMember, error)
GetMemberLastViewedAt(ctx context.Context, channelID string, userID string) (int64, error)
GetChannelMembersTimezones(channelID string) ([]model.StringMap, error)
GetAllChannelMembersForUser(userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error)
GetAllChannelMembersForUser(ctx request.CTX, userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error)
GetChannelsMemberCount(channelIDs []string) (map[string]int64, error)
InvalidateAllChannelMembersForUser(userID string)
GetAllChannelMembersNotifyPropsForChannel(channelID string, allowFromCache bool) (map[string]model.StringMap, error)
@ -422,7 +422,7 @@ type UserStore interface {
GetProfilesInChannel(options *model.UserGetOptions) ([]*model.User, error)
GetProfilesInChannelByStatus(options *model.UserGetOptions) ([]*model.User, error)
GetProfilesInChannelByAdmin(options *model.UserGetOptions) ([]*model.User, error)
GetAllProfilesInChannel(ctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error)
GetAllProfilesInChannel(rctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error)
GetProfilesNotInChannel(teamID string, channelId string, groupConstrained bool, offset int, limit int, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, error)
GetProfilesWithoutTeam(options *model.UserGetOptions) ([]*model.User, error)
GetProfilesByUsernames(usernames []string, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, error)

View File

@ -3646,27 +3646,27 @@ func testChannelStoreGetChannels(t *testing.T, rctx request.CTX, ss store.Store)
require.Equal(t, o2.Id, list[1].Id, "missing channel")
require.Equal(t, o3.Id, list[2].Id, "missing channel")
ids, err := ss.Channel().GetAllChannelMembersForUser(m1.UserId, false, false)
ids, err := ss.Channel().GetAllChannelMembersForUser(rctx, m1.UserId, false, false)
require.NoError(t, err)
_, ok := ids[o1.Id]
require.True(t, ok, "missing channel")
ids2, err := ss.Channel().GetAllChannelMembersForUser(m1.UserId, true, false)
ids2, err := ss.Channel().GetAllChannelMembersForUser(rctx, m1.UserId, true, false)
require.NoError(t, err)
_, ok = ids2[o1.Id]
require.True(t, ok, "missing channel")
ids3, err := ss.Channel().GetAllChannelMembersForUser(m1.UserId, true, false)
ids3, err := ss.Channel().GetAllChannelMembersForUser(rctx, m1.UserId, true, false)
require.NoError(t, err)
_, ok = ids3[o1.Id]
require.True(t, ok, "missing channel")
ids4, err := ss.Channel().GetAllChannelMembersForUser(m1.UserId, true, true)
ids4, err := ss.Channel().GetAllChannelMembersForUser(rctx, m1.UserId, true, true)
require.NoError(t, err)
_, ok = ids4[o1.Id]
require.True(t, ok, "missing channel")
ids5, err := ss.Channel().GetAllChannelMembersForUser(model.NewId(), true, true)
ids5, err := ss.Channel().GetAllChannelMembersForUser(rctx, model.NewId(), true, true)
require.NoError(t, err)
require.True(t, len(ids5) == 0)

View File

@ -534,9 +534,9 @@ func (_m *ChannelStore) GetAllChannelMemberIdsByChannelId(id string) ([]string,
return r0, r1
}
// GetAllChannelMembersForUser provides a mock function with given fields: userID, allowFromCache, includeDeleted
func (_m *ChannelStore) GetAllChannelMembersForUser(userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
ret := _m.Called(userID, allowFromCache, includeDeleted)
// GetAllChannelMembersForUser provides a mock function with given fields: ctx, userID, allowFromCache, includeDeleted
func (_m *ChannelStore) GetAllChannelMembersForUser(ctx request.CTX, userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
ret := _m.Called(ctx, userID, allowFromCache, includeDeleted)
if len(ret) == 0 {
panic("no return value specified for GetAllChannelMembersForUser")
@ -544,19 +544,19 @@ func (_m *ChannelStore) GetAllChannelMembersForUser(userID string, allowFromCach
var r0 map[string]string
var r1 error
if rf, ok := ret.Get(0).(func(string, bool, bool) (map[string]string, error)); ok {
return rf(userID, allowFromCache, includeDeleted)
if rf, ok := ret.Get(0).(func(request.CTX, string, bool, bool) (map[string]string, error)); ok {
return rf(ctx, userID, allowFromCache, includeDeleted)
}
if rf, ok := ret.Get(0).(func(string, bool, bool) map[string]string); ok {
r0 = rf(userID, allowFromCache, includeDeleted)
if rf, ok := ret.Get(0).(func(request.CTX, string, bool, bool) map[string]string); ok {
r0 = rf(ctx, userID, allowFromCache, includeDeleted)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]string)
}
}
if rf, ok := ret.Get(1).(func(string, bool, bool) error); ok {
r1 = rf(userID, allowFromCache, includeDeleted)
if rf, ok := ret.Get(1).(func(request.CTX, string, bool, bool) error); ok {
r1 = rf(ctx, userID, allowFromCache, includeDeleted)
} else {
r1 = ret.Error(1)
}

View File

@ -479,9 +479,9 @@ func (_m *UserStore) GetAllProfiles(options *model.UserGetOptions) ([]*model.Use
return r0, r1
}
// GetAllProfilesInChannel provides a mock function with given fields: ctx, channelID, allowFromCache
func (_m *UserStore) GetAllProfilesInChannel(ctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
ret := _m.Called(ctx, channelID, allowFromCache)
// GetAllProfilesInChannel provides a mock function with given fields: rctx, channelID, allowFromCache
func (_m *UserStore) GetAllProfilesInChannel(rctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
ret := _m.Called(rctx, channelID, allowFromCache)
if len(ret) == 0 {
panic("no return value specified for GetAllProfilesInChannel")
@ -490,10 +490,10 @@ func (_m *UserStore) GetAllProfilesInChannel(ctx context.Context, channelID stri
var r0 map[string]*model.User
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, bool) (map[string]*model.User, error)); ok {
return rf(ctx, channelID, allowFromCache)
return rf(rctx, channelID, allowFromCache)
}
if rf, ok := ret.Get(0).(func(context.Context, string, bool) map[string]*model.User); ok {
r0 = rf(ctx, channelID, allowFromCache)
r0 = rf(rctx, channelID, allowFromCache)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]*model.User)
@ -501,7 +501,7 @@ func (_m *UserStore) GetAllProfilesInChannel(ctx context.Context, channelID stri
}
if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok {
r1 = rf(ctx, channelID, allowFromCache)
r1 = rf(rctx, channelID, allowFromCache)
} else {
r1 = ret.Error(1)
}

View File

@ -933,10 +933,10 @@ func (s *TimerLayerChannelStore) GetAllChannelMemberIdsByChannelId(id string) ([
return result, err
}
func (s *TimerLayerChannelStore) GetAllChannelMembersForUser(userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
func (s *TimerLayerChannelStore) GetAllChannelMembersForUser(ctx request.CTX, userID string, allowFromCache bool, includeDeleted bool) (map[string]string, error) {
start := time.Now()
result, err := s.ChannelStore.GetAllChannelMembersForUser(userID, allowFromCache, includeDeleted)
result, err := s.ChannelStore.GetAllChannelMembersForUser(ctx, userID, allowFromCache, includeDeleted)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@ -10341,10 +10341,10 @@ func (s *TimerLayerUserStore) GetAllProfiles(options *model.UserGetOptions) ([]*
return result, err
}
func (s *TimerLayerUserStore) GetAllProfilesInChannel(ctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
func (s *TimerLayerUserStore) GetAllProfilesInChannel(rctx context.Context, channelID string, allowFromCache bool) (map[string]*model.User, error) {
start := time.Now()
result, err := s.UserStore.GetAllProfilesInChannel(ctx, channelID, allowFromCache)
result, err := s.UserStore.GetAllProfilesInChannel(rctx, channelID, allowFromCache)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {