diff --git a/app/app_iface.go b/app/app_iface.go index 3ee2632dd5..9e84732fa8 100644 --- a/app/app_iface.go +++ b/app/app_iface.go @@ -1071,7 +1071,7 @@ type AppIface interface { UpdateScheme(scheme *model.Scheme) (*model.Scheme, *model.AppError) UpdateSessionsIsGuest(userID string, isGuest bool) UpdateSharedChannel(sc *model.SharedChannel) (*model.SharedChannel, error) - UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error + UpdateSharedChannelRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error UpdateSidebarCategories(userID, teamID string, categories []*model.SidebarCategoryWithChannels) ([]*model.SidebarCategoryWithChannels, *model.AppError) UpdateSidebarCategoryOrder(userID, teamID string, categoryOrder []string) *model.AppError UpdateTeam(team *model.Team) (*model.Team, *model.AppError) diff --git a/app/opentracing/opentracing_layer.go b/app/opentracing/opentracing_layer.go index 649dd1271a..defab3a207 100644 --- a/app/opentracing/opentracing_layer.go +++ b/app/opentracing/opentracing_layer.go @@ -16427,9 +16427,9 @@ func (a *OpenTracingAppLayer) UpdateSharedChannel(sc *model.SharedChannel) (*mod return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error { +func (a *OpenTracingAppLayer) UpdateSharedChannelRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { origCtx := a.ctx - span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateSharedChannelRemoteNextSyncAt") + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateSharedChannelRemoteCursor") a.ctx = newCtx a.app.Srv().Store.SetContext(newCtx) @@ -16439,7 +16439,7 @@ func (a *OpenTracingAppLayer) UpdateSharedChannelRemoteNextSyncAt(id string, syn }() defer span.Finish() - resultVar0 := a.app.UpdateSharedChannelRemoteNextSyncAt(id, syncTime) + resultVar0 := a.app.UpdateSharedChannelRemoteCursor(id, cursor) if resultVar0 != nil { span.LogFields(spanlog.Error(resultVar0)) diff --git a/app/post_test.go b/app/post_test.go index f7b9bed5d3..4105f6be12 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -2190,8 +2190,8 @@ func TestSharedChannelSyncForPostActions(t *testing.T) { }, channel, false, true) require.Nil(t, err, "Creating a post should not error") - assert.Len(t, remoteClusterService.notifications, 1) - assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) + assert.Len(t, remoteClusterService.channelNotifications, 1) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[0]) }) t.Run("updating a post in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { @@ -2217,9 +2217,9 @@ func TestSharedChannelSyncForPostActions(t *testing.T) { _, err = th.App.UpdatePost(th.Context, post, true) require.Nil(t, err, "Updating a post should not error") - assert.Len(t, remoteClusterService.notifications, 2) - assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) - assert.Equal(t, channel.Id, remoteClusterService.notifications[1]) + assert.Len(t, remoteClusterService.channelNotifications, 2) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[0]) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[1]) }) t.Run("deleting a post in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { @@ -2246,9 +2246,9 @@ func TestSharedChannelSyncForPostActions(t *testing.T) { require.Nil(t, err, "Deleting a post should not error") // one creation and two deletes - assert.Len(t, remoteClusterService.notifications, 3) - assert.Equal(t, channel.Id, remoteClusterService.notifications[0]) - assert.Equal(t, channel.Id, remoteClusterService.notifications[1]) - assert.Equal(t, channel.Id, remoteClusterService.notifications[2]) + assert.Len(t, remoteClusterService.channelNotifications, 3) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[0]) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[1]) + assert.Equal(t, channel.Id, remoteClusterService.channelNotifications[2]) }) } diff --git a/app/reaction_test.go b/app/reaction_test.go index 098ebb434f..93aaef4993 100644 --- a/app/reaction_test.go +++ b/app/reaction_test.go @@ -44,9 +44,9 @@ func TestSharedChannelSyncForReactionActions(t *testing.T) { th.TearDown() // We need to enforce teardown because reaction instrumentation happens in a goroutine - assert.Len(t, sharedChannelService.notifications, 2) - assert.Equal(t, channel.Id, sharedChannelService.notifications[0]) - assert.Equal(t, channel.Id, sharedChannelService.notifications[1]) + assert.Len(t, sharedChannelService.channelNotifications, 2) + assert.Equal(t, channel.Id, sharedChannelService.channelNotifications[0]) + assert.Equal(t, channel.Id, sharedChannelService.channelNotifications[1]) }) t.Run("removing a reaction in a shared channel performs a content sync when sync service is running on that node", func(t *testing.T) { @@ -79,8 +79,8 @@ func TestSharedChannelSyncForReactionActions(t *testing.T) { th.TearDown() // We need to enforce teardown because reaction instrumentation happens in a goroutine - assert.Len(t, sharedChannelService.notifications, 2) - assert.Equal(t, channel.Id, sharedChannelService.notifications[0]) - assert.Equal(t, channel.Id, sharedChannelService.notifications[1]) + assert.Len(t, sharedChannelService.channelNotifications, 2) + assert.Equal(t, channel.Id, sharedChannelService.channelNotifications[0]) + assert.Equal(t, channel.Id, sharedChannelService.channelNotifications[1]) }) } diff --git a/app/server.go b/app/server.go index 319c17dff7..5f72fa9f31 100644 --- a/app/server.go +++ b/app/server.go @@ -159,6 +159,7 @@ type Server struct { telemetryService *telemetry.TelemetryService + serviceMux sync.RWMutex remoteClusterService remotecluster.RemoteClusterServiceIFace sharedChannelService SharedChannelServiceIFace @@ -873,16 +874,19 @@ func (s *Server) startInterClusterServices(license *model.License, app *App) err var err error - s.remoteClusterService, err = remotecluster.NewRemoteClusterService(s) + rcs, err := remotecluster.NewRemoteClusterService(s) if err != nil { return err } - if err = s.remoteClusterService.Start(); err != nil { - s.remoteClusterService = nil + if err = rcs.Start(); err != nil { return err } + s.serviceMux.Lock() + s.remoteClusterService = rcs + s.serviceMux.Unlock() + // Shared Channels service // License check @@ -897,15 +901,19 @@ func (s *Server) startInterClusterServices(license *model.License, app *App) err return nil } - s.sharedChannelService, err = sharedchannel.NewSharedChannelService(s, app) + scs, err := sharedchannel.NewSharedChannelService(s, app) if err != nil { return err } - if err = s.sharedChannelService.Start(); err != nil { - s.remoteClusterService = nil + if err = scs.Start(); err != nil { return err } + + s.serviceMux.Lock() + s.sharedChannelService = scs + s.serviceMux.Unlock() + return nil } @@ -967,11 +975,18 @@ func (s *Server) Shutdown() { mlog.Warn("Unable to cleanly shutdown telemetry client", mlog.Err(err)) } + s.serviceMux.RLock() + if s.sharedChannelService != nil { + if err = s.sharedChannelService.Shutdown(); err != nil { + mlog.Error("Error shutting down shared channel services", mlog.Err(err)) + } + } if s.remoteClusterService != nil { if err = s.remoteClusterService.Shutdown(); err != nil { mlog.Error("Error shutting down intercluster services", mlog.Err(err)) } } + s.serviceMux.RUnlock() s.StopHTTPServer() s.stopLocalModeServer() @@ -1992,12 +2007,16 @@ func (s *Server) GetStore() store.Store { // GetRemoteClusterService returns the `RemoteClusterService` instantiated by the server. // May be nil if the service is not enabled via license. func (s *Server) GetRemoteClusterService() remotecluster.RemoteClusterServiceIFace { + s.serviceMux.RLock() + defer s.serviceMux.RUnlock() return s.remoteClusterService } // GetSharedChannelSyncService returns the `SharedChannelSyncService` instantiated by the server. // May be nil if the service is not enabled via license. func (s *Server) GetSharedChannelSyncService() SharedChannelServiceIFace { + s.serviceMux.RLock() + defer s.serviceMux.RUnlock() return s.sharedChannelService } @@ -2010,12 +2029,16 @@ func (s *Server) GetMetrics() einterfaces.MetricsInterface { // SetRemoteClusterService sets the `RemoteClusterService` to be used by the server. // For testing only. func (s *Server) SetRemoteClusterService(remoteClusterService remotecluster.RemoteClusterServiceIFace) { + s.serviceMux.Lock() + defer s.serviceMux.Unlock() s.remoteClusterService = remoteClusterService } // SetSharedChannelSyncService sets the `SharedChannelSyncService` to be used by the server. // For testing only. func (s *Server) SetSharedChannelSyncService(sharedChannelService SharedChannelServiceIFace) { + s.serviceMux.Lock() + defer s.serviceMux.Unlock() s.sharedChannelService = sharedChannelService } diff --git a/app/shared_channel.go b/app/shared_channel.go index 5fdd64eeb4..6961ecae18 100644 --- a/app/shared_channel.go +++ b/app/shared_channel.go @@ -133,8 +133,8 @@ func (a *App) GetRemoteClusterForUser(remoteID string, userID string) (*model.Re return rc, nil } -func (a *App) UpdateSharedChannelRemoteNextSyncAt(id string, syncTime int64) error { - return a.Srv().Store.SharedChannel().UpdateRemoteNextSyncAt(id, syncTime) +func (a *App) UpdateSharedChannelRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { + return a.Srv().Store.SharedChannel().UpdateRemoteCursor(id, cursor) } func (a *App) DeleteSharedChannelRemote(id string) (bool, error) { @@ -153,3 +153,13 @@ func (a *App) GetSharedChannelRemotesStatus(channelID string) ([]*model.SharedCh func (a *App) NotifySharedChannelUserUpdate(user *model.User) { a.sendUpdatedUserEvent(*user) } + +// onUserProfileChange is called when a user's profile has changed +// (username, email, profile image, ...) +func (a *App) onUserProfileChange(userID string) { + syncService := a.Srv().GetSharedChannelSyncService() + if syncService == nil || !syncService.Active() { + return + } + syncService.NotifyUserProfileChanged(userID) +} diff --git a/app/shared_channel_notifier_test.go b/app/shared_channel_notifier_test.go index db75af42c0..5f56dab957 100644 --- a/app/shared_channel_notifier_test.go +++ b/app/shared_channel_notifier_test.go @@ -18,10 +18,10 @@ func TestServerSyncSharedChannelHandler(t *testing.T) { mockService := NewMockSharedChannelService(nil) mockService.active = false - th.App.srv.sharedChannelService = mockService + th.App.srv.SetSharedChannelSyncService(mockService) th.App.srv.SharedChannelSyncHandler(&model.WebSocketEvent{}) - assert.Empty(t, mockService.notifications) + assert.Empty(t, mockService.channelNotifications) }) t.Run("sync service active and broadcast envelope has ineligible event, it does nothing", func(t *testing.T) { @@ -30,13 +30,13 @@ func TestServerSyncSharedChannelHandler(t *testing.T) { mockService := NewMockSharedChannelService(nil) mockService.active = true - th.App.srv.sharedChannelService = mockService + th.App.srv.SetSharedChannelSyncService(mockService) channel := th.CreateChannel(th.BasicTeam, WithShared(true)) websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_ADDED_TO_TEAM, model.NewId(), channel.Id, "", nil) th.App.srv.SharedChannelSyncHandler(websocketEvent) - assert.Empty(t, mockService.notifications) + assert.Empty(t, mockService.channelNotifications) }) t.Run("sync service active and broadcast envelope has eligible event but channel does not exist, it does nothing", func(t *testing.T) { @@ -45,12 +45,12 @@ func TestServerSyncSharedChannelHandler(t *testing.T) { mockService := NewMockSharedChannelService(nil) mockService.active = true - th.App.srv.sharedChannelService = mockService + th.App.srv.SetSharedChannelSyncService(mockService) websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POSTED, model.NewId(), model.NewId(), "", nil) th.App.srv.SharedChannelSyncHandler(websocketEvent) - assert.Empty(t, mockService.notifications) + assert.Empty(t, mockService.channelNotifications) }) t.Run("sync service active when received eligible event, it triggers a shared channel content sync", func(t *testing.T) { @@ -59,13 +59,13 @@ func TestServerSyncSharedChannelHandler(t *testing.T) { mockService := NewMockSharedChannelService(nil) mockService.active = true - th.App.srv.sharedChannelService = mockService + th.App.srv.SetSharedChannelSyncService(mockService) channel := th.CreateChannel(th.BasicTeam, WithShared(true)) websocketEvent := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POSTED, model.NewId(), channel.Id, "", nil) th.App.srv.SharedChannelSyncHandler(websocketEvent) - assert.Len(t, mockService.notifications, 1) - assert.Equal(t, channel.Id, mockService.notifications[0]) + assert.Len(t, mockService.channelNotifications, 1) + assert.Equal(t, channel.Id, mockService.channelNotifications[0]) }) } diff --git a/app/shared_channel_service_iface.go b/app/shared_channel_service_iface.go index a074884f3b..a6fffe332b 100644 --- a/app/shared_channel_service_iface.go +++ b/app/shared_channel_service_iface.go @@ -13,6 +13,7 @@ type SharedChannelServiceIFace interface { Shutdown() error Start() error NotifyChannelChanged(channelId string) + NotifyUserProfileChanged(userID string) SendChannelInvite(channel *model.Channel, userId string, rc *model.RemoteCluster, options ...sharedchannel.InviteOption) error Active() bool } @@ -26,7 +27,7 @@ func MockOptionSharedChannelServiceWithActive(active bool) MockOptionSharedChann } func NewMockSharedChannelService(service SharedChannelServiceIFace, options ...MockOptionSharedChannelService) *mockSharedChannelService { - mrcs := &mockSharedChannelService{service, true, []string{}, 0} + mrcs := &mockSharedChannelService{service, true, []string{}, []string{}, 0} for _, option := range options { option(mrcs) } @@ -35,13 +36,18 @@ func NewMockSharedChannelService(service SharedChannelServiceIFace, options ...M type mockSharedChannelService struct { SharedChannelServiceIFace - active bool - notifications []string - numInvitations int + active bool + channelNotifications []string + userProfileNotifications []string + numInvitations int } func (mrcs *mockSharedChannelService) NotifyChannelChanged(channelId string) { - mrcs.notifications = append(mrcs.notifications, channelId) + mrcs.channelNotifications = append(mrcs.channelNotifications, channelId) +} + +func (mrcs *mockSharedChannelService) NotifyUserProfileChanged(userId string) { + mrcs.userProfileNotifications = append(mrcs.userProfileNotifications, userId) } func (mrcs *mockSharedChannelService) Shutdown() error { diff --git a/app/user.go b/app/user.go index ca9f658522..f2dcaa3530 100644 --- a/app/user.go +++ b/app/user.go @@ -981,6 +981,7 @@ func (a *App) SetProfileImageFromFile(userID string, file io.Reader) *model.AppE mlog.Warn("Error with updating last picture update", mlog.Err(err)) } a.invalidateUserCacheAndPublish(userID) + a.onUserProfileChange(userID) return nil } @@ -1311,6 +1312,7 @@ func (a *App) UpdateUser(user *model.User, sendNotifications bool) (*model.User, } a.InvalidateCacheForUser(user.Id) + a.onUserProfileChange(user.Id) return userUpdate.New, nil } diff --git a/model/post.go b/model/post.go index c305cd8207..ff5795cb4a 100644 --- a/model/post.go +++ b/model/post.go @@ -241,15 +241,15 @@ type GetPostsSinceOptions struct { SortAscending bool } +type GetPostsSinceForSyncCursor struct { + LastPostUpdateAt int64 + LastPostId string +} + type GetPostsSinceForSyncOptions struct { ChannelId string - Since int64 // inclusive - Until int64 // inclusive - SortDescending bool ExcludeRemoteId string IncludeDeleted bool - Limit int - Offset int } type GetPostsOptions struct { @@ -472,6 +472,14 @@ func (o *Post) IsRemote() bool { return o.RemoteId != nil && *o.RemoteId != "" } +// GetRemoteID safely returns the remoteID or empty string if not remote. +func (o *Post) GetRemoteID() string { + if o.RemoteId != nil { + return *o.RemoteId + } + return "" +} + func (o *Post) IsJoinLeaveMessage() bool { return o.Type == POST_JOIN_LEAVE || o.Type == POST_ADD_REMOVE || diff --git a/model/shared_channel.go b/model/shared_channel.go index 1cffbe755b..e3643812e6 100644 --- a/model/shared_channel.go +++ b/model/shared_channel.go @@ -112,7 +112,8 @@ type SharedChannelRemote struct { IsInviteAccepted bool `json:"is_invite_accepted"` IsInviteConfirmed bool `json:"is_invite_confirmed"` RemoteId string `json:"remote_id"` - NextSyncAt int64 `json:"next_sync_at"` + LastPostUpdateAt int64 `json:"last_post_update_at"` + LastPostId string `json:"last_post_id"` } func (sc *SharedChannelRemote) ToJson() string { @@ -211,6 +212,12 @@ func (scu *SharedChannelUser) IsValid() *AppError { return nil } +type GetUsersForSyncFilter struct { + CheckProfileImage bool + ChannelID string + Limit uint64 +} + // SharedChannelAttachment stores a lastSyncAt timestamp on behalf of a remote cluster for // each file attachment that has been synchronized. type SharedChannelAttachment struct { diff --git a/services/sharedchannel/attachment.go b/services/sharedchannel/attachment.go index 4bd2a3085d..9382b78456 100644 --- a/services/sharedchannel/attachment.go +++ b/services/sharedchannel/attachment.go @@ -15,23 +15,6 @@ import ( "github.com/mattermost/mattermost-server/v5/shared/mlog" ) -// postToAttachments returns the file attachments for a post that need to be synchronized. -func (scs *Service) postToAttachments(post *model.Post, rc *model.RemoteCluster) ([]*model.FileInfo, error) { - infos := make([]*model.FileInfo, 0) - - fis, err := scs.server.GetStore().FileInfo().GetForPost(post.Id, false, true, true) - if err != nil { - return nil, fmt.Errorf("could not get file info for attachment: %w", err) - } - - for _, fi := range fis { - if scs.shouldSyncAttachment(fi, rc) { - infos = append(infos, fi) - } - } - return infos, nil -} - // postsToAttachments returns the file attachments for a slice of posts that need to be synchronized. func (scs *Service) shouldSyncAttachment(fi *model.FileInfo, rc *model.RemoteCluster) bool { sca, err := scs.server.GetStore().SharedChannel().GetAttachment(fi.Id, rc.RemoteId) diff --git a/services/sharedchannel/getpostssince.go b/services/sharedchannel/getpostssince.go deleted file mode 100644 index 26d36977fd..0000000000 --- a/services/sharedchannel/getpostssince.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See LICENSE.txt for license information. - -package sharedchannel - -import ( - "github.com/mattermost/mattermost-server/v5/model" - "github.com/mattermost/mattermost-server/v5/shared/mlog" -) - -type sinceResult struct { - posts []*model.Post - hasMore bool - nextSince int64 -} - -// getPostsSince fetches posts that need to be synchronized with a remote cluster. -// There is a soft cap on the number of posts that will be synchronized in a single pass (MaxPostsPerSync). -// -// There is a special case where multiple posts have the same UpdateAt value. It is vital that this method -// include all posts within that millisecond so that subsequent calls can use an incremented `since`. If this -// method were to be called repeatedly with the same `since` value the same records would be returned each time -// and the sync would never move forward. -// -// A boolean is also returned to indicate if there are more posts to be synchronized (true) or not (false). -func (scs *Service) getPostsSince(channelId string, rc *model.RemoteCluster, since int64) (sinceResult, error) { - opts := model.GetPostsSinceForSyncOptions{ - ChannelId: channelId, - Since: since, - IncludeDeleted: true, - Limit: MaxPostsPerSync + 1, // ask for 1 more than needed to peek at first post in next batch - } - posts, err := scs.server.GetStore().Post().GetPostsSinceForSync(opts, true) - if err != nil { - return sinceResult{}, err - } - - if len(posts) == 0 { - return sinceResult{nextSince: since}, nil - } - - var hasMore bool - if len(posts) > MaxPostsPerSync { - hasMore = true - peekUpdateAt := posts[len(posts)-1].UpdateAt - posts = posts[:MaxPostsPerSync] // trim the peeked at record - - // If the last post to be synchronized has the same Update value as the first post in the next batch - // then we need to grab the rest of the posts for that millisecond to ensure the next call can have an - // incremented `since`. - if peekUpdateAt == posts[len(posts)-1].UpdateAt { - opts.Since = peekUpdateAt - opts.Until = opts.Since - opts.Limit = 1000 - opts.Offset = countPostsAtMillisecond(posts, peekUpdateAt) - - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "getPostsSince handling updateAt collision", - mlog.String("remote", rc.DisplayName), - mlog.Int64("update_at", peekUpdateAt), - mlog.Int("offset", opts.Offset), - ) - - morePosts, err := scs.server.GetStore().Post().GetPostsSinceForSync(opts, true) - if err != nil { - return sinceResult{}, err - } - posts = append(posts, morePosts...) - } - } - return sinceResult{posts: posts, hasMore: hasMore, nextSince: posts[len(posts)-1].UpdateAt + 1}, nil -} - -func countPostsAtMillisecond(posts []*model.Post, milli int64) int { - // walk backward through the slice until we find a post with UpdateAt that differs from milli. - var count int - for i := len(posts) - 1; i >= 0; i-- { - if posts[i].UpdateAt != milli { - return count - } - count++ - } - return count -} diff --git a/services/sharedchannel/msg.go b/services/sharedchannel/msg.go index b7eb92b05f..a1bdd62792 100644 --- a/services/sharedchannel/msg.go +++ b/services/sharedchannel/msg.go @@ -4,28 +4,29 @@ package sharedchannel import ( - "context" "encoding/json" - "strings" - "time" "github.com/mattermost/mattermost-server/v5/model" - "github.com/mattermost/mattermost-server/v5/services/remotecluster" - "github.com/mattermost/mattermost-server/v5/shared/mlog" ) // syncMsg represents a change in content (post add/edit/delete, reaction add/remove, users). // It is sent to remote clusters as the payload of a `RemoteClusterMsg`. type syncMsg struct { - ChannelId string `json:"channel_id"` - PostId string `json:"post_id"` - Post *model.Post `json:"post"` - Users []*model.User `json:"users"` - Reactions []*model.Reaction `json:"reactions"` - Attachments []*model.FileInfo `json:"-"` + Id string `json:"id"` + ChannelId string `json:"channel_id"` + Users map[string]*model.User `json:"users,omitempty"` + Posts []*model.Post `json:"posts,omitempty"` + Reactions []*model.Reaction `json:"reactions,omitempty"` } -func (sm syncMsg) ToJSON() ([]byte, error) { +func newSyncMsg(channelID string) *syncMsg { + return &syncMsg{ + Id: model.NewId(), + ChannelId: channelID, + } +} + +func (sm *syncMsg) ToJSON() ([]byte, error) { b, err := json.Marshal(sm) if err != nil { return nil, err @@ -33,296 +34,10 @@ func (sm syncMsg) ToJSON() ([]byte, error) { return b, nil } -func (sm syncMsg) String() string { +func (sm *syncMsg) String() string { json, err := sm.ToJSON() if err != nil { return "" } return string(json) } - -type userCache map[string]struct{} - -func (u userCache) Has(id string) bool { - _, ok := u[id] - return ok -} - -func (u userCache) Add(id string) { - u[id] = struct{}{} -} - -// postsToSyncMessages takes a slice of posts and converts to a `RemoteClusterMsg` which can be -// sent to a remote cluster. -func (scs *Service) postsToSyncMessages(posts []*model.Post, channelID string, rc *model.RemoteCluster, nextSyncAt int64) ([]syncMsg, error) { - syncMessages := make([]syncMsg, 0, len(posts)) - - var teamID string - uCache := make(userCache) - - for _, p := range posts { - if p.IsSystemMessage() { // don't sync system messages - continue - } - - // lookup team id once - if teamID == "" { - sc, err := scs.server.GetStore().SharedChannel().Get(p.ChannelId) - if err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Could not get shared channel for post", - mlog.String("post_id", p.Id), - mlog.Err(err), - ) - continue - } - teamID = sc.TeamId - } - - // any reactions originating from the remote cluster are filtered out - reactions, err := scs.server.GetStore().Reaction().GetForPostSince(p.Id, nextSyncAt, rc.RemoteId, true) - if err != nil { - return nil, err - } - - postSync := p - - // Don't resend an existing post where only the reactions changed. - // Posts we must send: - // - new posts (EditAt == 0) - // - edited posts (EditAt >= nextSyncAt) - // - deleted posts (DeleteAt > 0) - if p.EditAt > 0 && p.EditAt < nextSyncAt && p.DeleteAt == 0 { - postSync = nil - } - - // Don't send a deleted post if it is just the original copy from an edit. - if p.DeleteAt > 0 && p.OriginalId != "" { - postSync = nil - } - - // don't sync a post back to the remote it came from. - if p.RemoteId != nil && *p.RemoteId == rc.RemoteId { - postSync = nil - } - - var attachments []*model.FileInfo - if postSync != nil { - // parse out all permalinks in the message. - postSync.Message = scs.processPermalinkToRemote(postSync) - - // get any file attachments - attachments, err = scs.postToAttachments(postSync, rc) - if err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Could not fetch attachments for post", - mlog.String("post_id", postSync.Id), - mlog.Err(err), - ) - } - } - - // any users originating from the remote cluster are filtered out - users := scs.usersForPost(postSync, reactions, channelID, teamID, rc, uCache) - - // if everything was filtered out then don't send an empty message. - if postSync == nil && len(reactions) == 0 && len(users) == 0 { - continue - } - - sm := syncMsg{ - ChannelId: p.ChannelId, - PostId: p.Id, - Post: postSync, - Users: users, - Reactions: reactions, - Attachments: attachments, - } - syncMessages = append(syncMessages, sm) - } - return syncMessages, nil -} - -// usersForPost provides a list of Users associated with the post that need to be synchronized. -// The user cache ensures the same user is not synchronized redundantly if they appear in multiple -// posts for this sync batch. -func (scs *Service) usersForPost(post *model.Post, reactions []*model.Reaction, channelID string, teamID string, rc *model.RemoteCluster, uCache userCache) []*model.User { - userIds := make([]string, 0) - var mentionMap model.UserMentionMap - - if post != nil && !uCache.Has(post.UserId) { - userIds = append(userIds, post.UserId) - uCache.Add(post.UserId) - } - - for _, r := range reactions { - if !uCache.Has(r.UserId) { - userIds = append(userIds, r.UserId) - uCache.Add(r.UserId) - } - } - - // get mentions and userids for each mention - if post != nil { - mentionMap = scs.app.MentionsToTeamMembers(post.Message, teamID) - for mention, id := range mentionMap { - if !uCache.Has(id) { - userIds = append(userIds, id) - uCache.Add(id) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Found mention", - mlog.String("mention", mention), - mlog.String("user_id", id), - ) - } - } - } - - users := make([]*model.User, 0) - - for _, id := range userIds { - user, err := scs.server.GetStore().User().Get(context.Background(), id) - if err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error checking if user should sync", - mlog.String("user_id", id), - mlog.Err(err), - ) - continue - } - - sync, syncImage, err2 := scs.shouldUserSync(user, channelID, rc) - if err2 != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Could not find user for post", - mlog.String("user_id", id), - mlog.Err(err2), - ) - continue - } - - if sync { - users = append(users, sanitizeUserForSync(user)) - } - - if syncImage { - scs.syncProfileImage(user, channelID, rc) - } - - // if this was a mention then put the real username in place of the username+remotename, but only - // when sending to the remote that the user belongs to. - if user.RemoteId != nil && *user.RemoteId == rc.RemoteId { - fixMention(post, mentionMap, user) - } - } - return users -} - -// fixMention replaces any mentions in a post for the user with the user's real username. -func fixMention(post *model.Post, mentionMap model.UserMentionMap, user *model.User) { - if post == nil || len(mentionMap) == 0 { - return - } - - realUsername, ok := user.GetProp(KeyRemoteUsername) - if !ok { - return - } - - // there may be more than one mention for each user so we have to walk the whole map. - for mention, id := range mentionMap { - if id == user.Id && strings.Contains(mention, ":") { - post.Message = strings.ReplaceAll(post.Message, "@"+mention, "@"+realUsername) - } - } -} - -func sanitizeUserForSync(user *model.User) *model.User { - user.Password = model.NewId() - user.AuthData = nil - user.AuthService = "" - user.Roles = "system_user" - user.AllowMarketing = false - user.NotifyProps = model.StringMap{} - user.LastPasswordUpdate = 0 - user.LastPictureUpdate = 0 - user.FailedAttempts = 0 - user.MfaActive = false - user.MfaSecret = "" - - return user -} - -// shouldUserSync determines if a user needs to be synchronized. -// User should be synchronized if it has no entry in the SharedChannelUsers table for the specified channel, -// or there is an entry but the LastSyncAt is less than user.UpdateAt -func (scs *Service) shouldUserSync(user *model.User, channelID string, rc *model.RemoteCluster) (sync bool, syncImage bool, err error) { - // don't sync users with the remote they originated from. - if user.RemoteId != nil && *user.RemoteId == rc.RemoteId { - return false, false, nil - } - - scu, err := scs.server.GetStore().SharedChannel().GetUser(user.Id, channelID, rc.RemoteId) - if err != nil { - if _, ok := err.(errNotFound); !ok { - return false, false, err - } - - // user not in the SharedChannelUsers table, so we must add them. - scu = &model.SharedChannelUser{ - UserId: user.Id, - RemoteId: rc.RemoteId, - ChannelId: channelID, - } - if _, err = scs.server.GetStore().SharedChannel().SaveUser(scu); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error adding user to shared channel users", - mlog.String("remote_id", rc.RemoteId), - mlog.String("user_id", user.Id), - mlog.String("channel_id", user.Id), - mlog.Err(err), - ) - } - return true, true, nil - } - - return user.UpdateAt > scu.LastSyncAt, user.LastPictureUpdate > scu.LastSyncAt, nil -} - -func (scs *Service) syncProfileImage(user *model.User, channelID string, rc *model.RemoteCluster) { - rcs := scs.server.GetRemoteClusterService() - if rcs == nil { - return - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - rcs.SendProfileImage(ctx, user.Id, rc, scs.app, func(userId string, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { - if resp.IsSuccess() { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Users profile image synchronized", - mlog.String("remote_id", rc.RemoteId), - mlog.String("user_id", user.Id), - ) - - scu, err := scs.server.GetStore().SharedChannel().GetUser(user.Id, channelID, rc.RemoteId) - if err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error fetching shared channel user while updating users LastSyncTime after profile image update", - mlog.String("remote_id", rc.RemoteId), - mlog.String("user_id", user.Id), - mlog.String("channel_id", channelID), - mlog.Err(err), - ) - } - - if err = scs.server.GetStore().SharedChannel().UpdateUserLastSyncAt(scu.Id, model.GetMillis()); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error updating users LastSyncTime after profile image update", - mlog.String("remote_id", rc.RemoteId), - mlog.String("user_id", user.Id), - mlog.Err(err), - ) - } - return - } - - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error synchronizing users profile image", - mlog.String("remote_id", rc.RemoteId), - mlog.String("user_id", user.Id), - mlog.String("Err", resp.Err), - ) - }) -} diff --git a/services/sharedchannel/response.go b/services/sharedchannel/response.go index fc85ae2cfa..61a9fc22f5 100644 --- a/services/sharedchannel/response.go +++ b/services/sharedchannel/response.go @@ -4,7 +4,13 @@ package sharedchannel type SyncResponse struct { - LastSyncAt int64 `json:"last_sync_at"` - PostErrors []string `json:"post_errors"` - UsersSyncd []string `json:"users_syncd"` + UsersLastUpdateAt int64 `json:"users_last_update_at"` + UserErrors []string `json:"user_errors"` + UsersSyncd []string `json:"users_syncd"` + + PostsLastUpdateAt int64 `json:"posts_last_update_at"` + PostErrors []string `json:"post_errors"` + + ReactionsLastUpdateAt int64 `json:"reactions_last_update_at"` + ReactionErrors []string `json:"reaction_errors"` } diff --git a/services/sharedchannel/service.go b/services/sharedchannel/service.go index 4206f3d3c4..bb43cdeafe 100644 --- a/services/sharedchannel/service.go +++ b/services/sharedchannel/service.go @@ -24,9 +24,11 @@ const ( TopicUploadCreate = "sharedchannel_upload" MaxRetries = 3 MaxPostsPerSync = 12 // a bit more than one typical screenfull of posts + MaxUsersPerSync = 25 NotifyRemoteOfflineThreshold = time.Second * 10 NotifyMinimumDelay = time.Second * 2 MaxUpsertRetries = 25 + ProfileImageSyncTimeout = time.Second * 5 KeyRemoteUsername = "RemoteUsername" KeyRemoteEmail = "RemoteEmail" ) diff --git a/services/sharedchannel/sync_recv.go b/services/sharedchannel/sync_recv.go index d8d97bc71f..a4bf301e7b 100644 --- a/services/sharedchannel/sync_recv.go +++ b/services/sharedchannel/sync_recv.go @@ -33,136 +33,131 @@ func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.R ) } - var syncMessages []syncMsg + var sm syncMsg - if err := json.Unmarshal(msg.Payload, &syncMessages); err != nil { + if err := json.Unmarshal(msg.Payload, &sm); err != nil { return fmt.Errorf("invalid sync message: %w", err) } - - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Batch of sync messages received", - mlog.String("remote", rc.DisplayName), - mlog.Int("sync_msg_count", len(syncMessages)), - ) - - return scs.processSyncMessages(syncMessages, rc, response) + return scs.processSyncMessage(&sm, rc, response) } -func (scs *Service) processSyncMessages(syncMessages []syncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { +func (scs *Service) processSyncMessage(syncMsg *syncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error { var channel *model.Channel var team *model.Team - postErrors := make([]string, 0) - usersSyncd := make([]string, 0) - var lastSyncAt int64 var err error + syncResp := SyncResponse{ + UserErrors: make([]string, 0), + UsersSyncd: make([]string, 0), + PostErrors: make([]string, 0), + ReactionErrors: make([]string, 0), + } - for _, sm := range syncMessages { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Sync msg received", - mlog.String("post_id", sm.PostId), - mlog.String("channel_id", sm.ChannelId), - mlog.Int("reaction_count", len(sm.Reactions)), - mlog.Int("user_count", len(sm.Users)), - mlog.Bool("has_post", sm.Post != nil), - ) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Sync msg received", + mlog.String("remote", rc.Name), + mlog.String("channel_id", syncMsg.ChannelId), + mlog.Int("user_count", len(syncMsg.Users)), + mlog.Int("post_count", len(syncMsg.Posts)), + mlog.Int("reaction_count", len(syncMsg.Reactions)), + ) - if channel == nil { - if channel, err = scs.server.GetStore().Channel().Get(sm.ChannelId, true); err != nil { - // if the channel doesn't exist then none of these sync messages are going to work. - return fmt.Errorf("channel not found processing sync messages: %w", err) - } - } - - // add/update users before posts - for _, user := range sm.Users { - if userSaved, err := scs.upsertSyncUser(user, channel, rc); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync user", - mlog.String("post_id", sm.PostId), - mlog.String("channel_id", sm.ChannelId), - mlog.String("user_id", user.Id), - mlog.Err(err)) - } else { - usersSyncd = append(usersSyncd, userSaved.Id) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "User upserted via sync", - mlog.String("post_id", sm.PostId), - mlog.String("channel_id", sm.ChannelId), - mlog.String("user_id", user.Id), - ) - } - } - - if sm.Post != nil { - if sm.ChannelId != sm.Post.ChannelId { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "ChannelId mismatch", - mlog.String("sm.ChannelId", sm.ChannelId), - mlog.String("sm.Post.ChannelId", sm.Post.ChannelId), - mlog.String("PostId", sm.Post.Id), - ) - postErrors = append(postErrors, sm.Post.Id) - continue - } - - if channel.Type != model.CHANNEL_DIRECT && team == nil { - var err2 error - team, err2 = scs.server.GetStore().Channel().GetTeamForChannel(sm.ChannelId) - if err2 != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error getting Team for Channel", - mlog.String("ChannelId", sm.Post.ChannelId), - mlog.String("PostId", sm.Post.Id), - mlog.Err(err2), - ) - postErrors = append(postErrors, sm.Post.Id) - continue - } - } - - // process perma-links for remote - if team != nil { - sm.Post.Message = scs.processPermalinkFromRemote(sm.Post, team) - } - - // add/update post - rpost, err := scs.upsertSyncPost(sm.Post, channel, rc) - if err != nil { - postErrors = append(postErrors, sm.Post.Id) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync post", - mlog.String("post_id", sm.Post.Id), - mlog.String("channel_id", sm.Post.ChannelId), - mlog.Err(err), - ) - } else if lastSyncAt < rpost.UpdateAt { - lastSyncAt = rpost.UpdateAt - } - } - - // add/remove reactions - for _, reaction := range sm.Reactions { - if _, err := scs.upsertSyncReaction(reaction, rc); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync reaction", - mlog.String("user_id", reaction.UserId), - mlog.String("post_id", reaction.PostId), - mlog.String("emoji", reaction.EmojiName), - mlog.Int64("delete_at", reaction.DeleteAt), - mlog.Err(err), - ) - } else { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Reaction upserted via sync", - mlog.String("user_id", reaction.UserId), - mlog.String("post_id", reaction.PostId), - mlog.String("emoji", reaction.EmojiName), - mlog.Int64("delete_at", reaction.DeleteAt), - ) - - if lastSyncAt < reaction.UpdateAt { - lastSyncAt = reaction.UpdateAt - } + if channel, err = scs.server.GetStore().Channel().Get(syncMsg.ChannelId, true); err != nil { + // if the channel doesn't exist then none of these sync items are going to work. + return fmt.Errorf("channel not found processing sync message: %w", err) + } + + // add/update users before posts + for _, user := range syncMsg.Users { + if userSaved, err := scs.upsertSyncUser(user, channel, rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync user", + mlog.String("remote", rc.Name), + mlog.String("channel_id", syncMsg.ChannelId), + mlog.String("user_id", user.Id), + mlog.Err(err)) + } else { + syncResp.UsersSyncd = append(syncResp.UsersSyncd, userSaved.Id) + if syncResp.UsersLastUpdateAt < user.UpdateAt { + syncResp.UsersLastUpdateAt = user.UpdateAt } + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "User upserted via sync", + mlog.String("remote", rc.Name), + mlog.String("channel_id", syncMsg.ChannelId), + mlog.String("user_id", user.Id), + ) } } - syncResp := SyncResponse{ - LastSyncAt: lastSyncAt, // might be zero - PostErrors: postErrors, // might be empty - UsersSyncd: usersSyncd, // might be empty + for _, post := range syncMsg.Posts { + if syncMsg.ChannelId != post.ChannelId { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "ChannelId mismatch", + mlog.String("remote", rc.Name), + mlog.String("sm.ChannelId", syncMsg.ChannelId), + mlog.String("sm.Post.ChannelId", post.ChannelId), + mlog.String("PostId", post.Id), + ) + syncResp.PostErrors = append(syncResp.PostErrors, post.Id) + continue + } + + if channel.Type != model.CHANNEL_DIRECT && team == nil { + var err2 error + team, err2 = scs.server.GetStore().Channel().GetTeamForChannel(syncMsg.ChannelId) + if err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error getting Team for Channel", + mlog.String("ChannelId", post.ChannelId), + mlog.String("PostId", post.Id), + mlog.String("remote", rc.Name), + mlog.Err(err2), + ) + syncResp.PostErrors = append(syncResp.PostErrors, post.Id) + continue + } + } + + // process perma-links for remote + if team != nil { + post.Message = scs.processPermalinkFromRemote(post, team) + } + + // add/update post + rpost, err := scs.upsertSyncPost(post, channel, rc) + if err != nil { + syncResp.PostErrors = append(syncResp.PostErrors, post.Id) + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + mlog.String("remote", rc.Name), + mlog.Err(err), + ) + } else if syncResp.PostsLastUpdateAt < rpost.UpdateAt { + syncResp.PostsLastUpdateAt = rpost.UpdateAt + } + } + + // add/remove reactions + for _, reaction := range syncMsg.Reactions { + if _, err := scs.upsertSyncReaction(reaction, rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync reaction", + mlog.String("remote", rc.Name), + mlog.String("user_id", reaction.UserId), + mlog.String("post_id", reaction.PostId), + mlog.String("emoji", reaction.EmojiName), + mlog.Int64("delete_at", reaction.DeleteAt), + mlog.Err(err), + ) + } else { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Reaction upserted via sync", + mlog.String("remote", rc.Name), + mlog.String("user_id", reaction.UserId), + mlog.String("post_id", reaction.PostId), + mlog.String("emoji", reaction.EmojiName), + mlog.Int64("delete_at", reaction.DeleteAt), + ) + + if syncResp.ReactionsLastUpdateAt < reaction.UpdateAt { + syncResp.ReactionsLastUpdateAt = reaction.UpdateAt + } + } } response.SetPayload(syncResp) @@ -345,24 +340,30 @@ func (scs *Service) upsertSyncPost(post *model.Post, channel *model.Channel, rc if rpost == nil { // post doesn't exist; create new one rpost, appErr = scs.app.CreatePost(request.EmptyContext(), post, channel, true, true) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Created sync post", - mlog.String("post_id", post.Id), - mlog.String("channel_id", post.ChannelId), - ) + if appErr == nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Created sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } } else if post.DeleteAt > 0 { // delete post rpost, appErr = scs.app.DeletePost(post.Id, post.UserId) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Deleted sync post", - mlog.String("post_id", post.Id), - mlog.String("channel_id", post.ChannelId), - ) + if appErr == nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Deleted sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } } else if post.EditAt > rpost.EditAt || post.Message != rpost.Message { // update post rpost, appErr = scs.app.UpdatePost(request.EmptyContext(), post, false) - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Updated sync post", - mlog.String("post_id", post.Id), - mlog.String("channel_id", post.ChannelId), - ) + if appErr == nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Updated sync post", + mlog.String("post_id", post.Id), + mlog.String("channel_id", post.ChannelId), + ) + } } else { // nothing to update scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Update to sync post ignored", diff --git a/services/sharedchannel/sync_send.go b/services/sharedchannel/sync_send.go index c4c7b7e621..fdcc4bf8b3 100644 --- a/services/sharedchannel/sync_send.go +++ b/services/sharedchannel/sync_send.go @@ -5,9 +5,7 @@ package sharedchannel import ( "context" - "encoding/json" "fmt" - "sync" "time" "github.com/mattermost/mattermost-server/v5/model" @@ -18,25 +16,25 @@ import ( type syncTask struct { id string - channelId string - remoteId string + channelID string + remoteID string AddedAt time.Time retryCount int - retryPost *model.Post + retryMsg *syncMsg schedule time.Time } -func newSyncTask(channelId string, remoteId string, retryPost *model.Post) syncTask { - var postId string - if retryPost != nil { - postId = retryPost.Id +func newSyncTask(channelID string, remoteID string, retryMsg *syncMsg) syncTask { + var retryID string + if retryMsg != nil { + retryID = retryMsg.Id } return syncTask{ - id: channelId + remoteId + postId, // combination of ids to avoid duplicates - channelId: channelId, - remoteId: remoteId, // empty means update all remote clusters - retryPost: retryPost, + id: channelID + remoteID + retryID, // combination of ids to avoid duplicates + channelID: channelID, + remoteID: remoteID, // empty means update all remote clusters + retryMsg: retryMsg, schedule: time.Now(), } } @@ -49,16 +47,52 @@ func (st *syncTask) incRetry() bool { // NotifyChannelChanged is called to indicate that a shared channel has been modified, // thus triggering an update to all remote clusters. -func (scs *Service) NotifyChannelChanged(channelId string) { +func (scs *Service) NotifyChannelChanged(channelID string) { if rcs := scs.server.GetRemoteClusterService(); rcs == nil { return } - task := newSyncTask(channelId, "", nil) + task := newSyncTask(channelID, "", nil) task.schedule = time.Now().Add(NotifyMinimumDelay) scs.addTask(task) } +// NotifyUserProfileChanged is called to indicate that a user belonging to at least one +// shared channel has modified their user profile (name, username, email, custom status, profile image) +func (scs *Service) NotifyUserProfileChanged(userID string) { + if rcs := scs.server.GetRemoteClusterService(); rcs == nil { + return + } + + scusers, err := scs.server.GetStore().SharedChannel().GetUsersForUser(userID) + if err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Failed to fetch shared channel users", + mlog.String("userID", userID), + mlog.Err(err), + ) + return + } + if len(scusers) == 0 { + return + } + + notified := make(map[string]struct{}) + + for _, user := range scusers { + // update every channel + remote combination they belong to. + // Redundant updates (ie. to same remote for multiple channels) will be + // filtered out. + combo := user.ChannelId + user.RemoteId + if _, ok := notified[combo]; ok { + continue + } + notified[combo] = struct{}{} + task := newSyncTask(user.ChannelId, user.RemoteId, nil) + task.schedule = time.Now().Add(NotifyMinimumDelay) + scs.addTask(task) + } +} + // ForceSyncForRemote causes all channels shared with the remote to be synchronized. func (scs *Service) ForceSyncForRemote(rc *model.RemoteCluster) { if rcs := scs.server.GetRemoteClusterService(); rcs == nil { @@ -155,8 +189,8 @@ func (scs *Service) doSync() time.Duration { scs.addTask(task) } else { scs.server.GetLogger().Error("Failed to synchronize shared channel", - mlog.String("channelId", task.channelId), - mlog.String("remoteId", task.remoteId), + mlog.String("channelId", task.channelID), + mlog.String("remoteId", task.remoteID), mlog.Err(err), ) } @@ -204,9 +238,9 @@ func (scs *Service) processTask(task syncTask) error { var err error var remotes []*model.RemoteCluster - if task.remoteId == "" { + if task.remoteID == "" { filter := model.RemoteClusterQueryFilter{ - InChannel: task.channelId, + InChannel: task.channelID, OnlyConfirmed: true, } remotes, err = scs.server.GetStore().RemoteCluster().GetAll(filter) @@ -214,28 +248,27 @@ func (scs *Service) processTask(task syncTask) error { return err } } else { - rc, err := scs.server.GetStore().RemoteCluster().Get(task.remoteId) + rc, err := scs.server.GetStore().RemoteCluster().Get(task.remoteID) if err != nil { return err } if !rc.IsOnline() { - return fmt.Errorf("Failed updating shared channel '%s' for offline remote cluster '%s'", task.channelId, rc.DisplayName) + return fmt.Errorf("Failed updating shared channel '%s' for offline remote cluster '%s'", task.channelID, rc.DisplayName) } remotes = []*model.RemoteCluster{rc} } for _, rc := range remotes { rtask := task - rtask.remoteId = rc.RemoteId - if err := scs.updateForRemote(rtask, rc); err != nil { + rtask.remoteID = rc.RemoteId + if err := scs.syncForRemote(rtask, rc); err != nil { // retry... if rtask.incRetry() { scs.addTask(rtask) } else { scs.server.GetLogger().Error("Failed to synchronize shared channel for remote cluster", - mlog.String("channelId", rtask.channelId), + mlog.String("channelId", rtask.channelID), mlog.String("remote", rc.DisplayName), - mlog.String("remoteId", rtask.remoteId), mlog.Err(err), ) } @@ -244,160 +277,8 @@ func (scs *Service) processTask(task syncTask) error { return nil } -// updateForRemote updates a remote cluster with any new posts/reactions for a specific -// channel. If many changes are found, only the oldest X changes are sent and the channel -// is re-added to the task map. This ensures no channels are starved for updates even if some -// channels are very active. -func (scs *Service) updateForRemote(task syncTask, rc *model.RemoteCluster) error { - rcs := scs.server.GetRemoteClusterService() - if rcs == nil { - return fmt.Errorf("cannot update remote cluster for channel id %s; Remote Cluster Service not enabled", task.channelId) - } - - scr, err := scs.server.GetStore().SharedChannel().GetRemoteByIds(task.channelId, rc.RemoteId) - if err != nil { - return err - } - - var posts []*model.Post - var repeat bool - nextSince := scr.NextSyncAt - - if task.retryPost != nil { - posts = []*model.Post{task.retryPost} - } else { - result, err2 := scs.getPostsSince(task.channelId, rc, scr.NextSyncAt) - if err2 != nil { - return err2 - } - posts = result.posts - repeat = result.hasMore - nextSince = result.nextSince - } - - if len(posts) == 0 { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task found zero posts; skipping sync", - mlog.String("remote", rc.DisplayName), - mlog.String("channel_id", task.channelId), - mlog.Int64("lastSyncAt", scr.NextSyncAt), - mlog.Int64("nextSince", nextSince), - mlog.Bool("repeat", repeat), - ) - return nil - } - - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task found posts to sync", - mlog.String("remote", rc.DisplayName), - mlog.String("channel_id", task.channelId), - mlog.Int64("lastSyncAt", scr.NextSyncAt), - mlog.Int64("nextSince", nextSince), - mlog.Int("count", len(posts)), - mlog.Bool("repeat", repeat), - ) - - if !rc.IsOnline() { - scs.notifyRemoteOffline(posts, rc) - return nil - } - - syncMessages, err := scs.postsToSyncMessages(posts, task.channelId, rc, scr.NextSyncAt) - if err != nil { - return err - } - - if len(syncMessages) == 0 { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "sync task, all messages filtered out; skipping sync", - mlog.String("remote", rc.DisplayName), - mlog.String("channel_id", task.channelId), - mlog.Bool("repeat", repeat), - ) - - // All posts were filtered out, meaning no need to send them. Fast forward SharedChannelRemote's NextSyncAt. - scs.updateNextSyncForRemote(scr.Id, rc, nextSince) - - // if there are more posts eligible to sync then schedule another sync - if repeat { - scs.addTask(newSyncTask(task.channelId, task.remoteId, nil)) - } - return nil - } - - scs.sendAttachments(syncMessages, rc) - - b, err := json.Marshal(syncMessages) - if err != nil { - return err - } - msg := model.NewRemoteClusterMsg(TopicSync, b) - - if scs.server.GetLogger().IsLevelEnabled(mlog.LvlSharedChannelServiceMessagesOutbound) { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceMessagesOutbound, "outbound message", - mlog.String("remote", rc.DisplayName), - mlog.Int64("NextSyncAt", scr.NextSyncAt), - mlog.String("msg", string(b)), - ) - } - - ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) - defer cancel() - - var wg sync.WaitGroup - wg.Add(1) - - err = rcs.SendMsg(ctx, msg, rc, func(msg model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { - defer wg.Done() - if err != nil { - return // this means the response could not be parsed; already logged - } - - var syncResp SyncResponse - if err2 := json.Unmarshal(resp.Payload, &syncResp); err2 != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "invalid sync response after update shared channel", - mlog.String("remote", rc.DisplayName), - mlog.Err(err2), - ) - } - - // Any Post(s) that failed to save on remote side are included in an array of post ids in the Response payload. - // Handle each error by retrying the post a fixed number of times before giving up. - for _, p := range syncResp.PostErrors { - scs.handlePostError(p, task, rc) - } - - // update NextSyncAt for all the users that were synchronized - scs.updateSyncUsers(syncResp.UsersSyncd, task.channelId, rc, nextSince) - }) - - wg.Wait() - - if err == nil { - // Optimistically update SharedChannelRemote's NextSyncAt; if any posts failed they will be retried. - scs.updateNextSyncForRemote(scr.Id, rc, nextSince) - } - - if repeat { - scs.addTask(newSyncTask(task.channelId, task.remoteId, nil)) - } - return err -} - -func (scs *Service) sendAttachments(syncMessages []syncMsg, rc *model.RemoteCluster) { - for _, sm := range syncMessages { - for _, fi := range sm.Attachments { - if err := scs.sendAttachmentForRemote(fi, sm.Post, rc); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error syncing attachment for post", - mlog.String("remote", rc.DisplayName), - mlog.String("post_id", sm.Post.Id), - mlog.String("file_id", fi.Id), - mlog.Err(err), - ) - } - } - } -} - func (scs *Service) handlePostError(postId string, task syncTask, rc *model.RemoteCluster) { - if task.retryPost != nil && task.retryPost.Id == postId { + if task.retryMsg != nil && len(task.retryMsg.Posts) == 1 && task.retryMsg.Posts[0].Id == postId { // this was a retry for specific post that failed previously. Try again if within MaxRetries. if task.incRetry() { scs.addTask(task) @@ -419,7 +300,11 @@ func (scs *Service) handlePostError(postId string, task syncTask, rc *model.Remo ) return } - scs.addTask(newSyncTask(task.channelId, task.remoteId, post)) + + syncMsg := newSyncMsg(task.channelID) + syncMsg.Posts = []*model.Post{post} + + scs.addTask(newSyncTask(task.channelID, task.remoteID, syncMsg)) } // notifyRemoteOffline creates an ephemeral post to the author for any posts created recently to remotes @@ -452,54 +337,22 @@ func (scs *Service) notifyRemoteOffline(posts []*model.Post, rc *model.RemoteClu } } -func (scs *Service) updateNextSyncForRemote(scrId string, rc *model.RemoteCluster, nextSyncAt int64) { - if nextSyncAt == 0 { - return - } - if err := scs.server.GetStore().SharedChannel().UpdateRemoteNextSyncAt(scrId, nextSyncAt); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error updating NextSyncAt for shared channel remote", +func (scs *Service) updateCursorForRemote(scrId string, rc *model.RemoteCluster, cursor model.GetPostsSinceForSyncCursor) { + if err := scs.server.GetStore().SharedChannel().UpdateRemoteCursor(scrId, cursor); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error updating cursor for shared channel remote", mlog.String("remote", rc.DisplayName), mlog.Err(err), ) return } - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "updated NextSyncAt for remote", + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "updated cursor for remote", mlog.String("remote_id", rc.RemoteId), mlog.String("remote", rc.DisplayName), - mlog.Int64("next_update_at", nextSyncAt), + mlog.Int64("last_post_update_at", cursor.LastPostUpdateAt), + mlog.String("last_post_id", cursor.LastPostId), ) } -func (scs *Service) updateSyncUsers(userIds []string, channelID string, rc *model.RemoteCluster, lastSyncAt int64) { - for _, uid := range userIds { - scu, err := scs.server.GetStore().SharedChannel().GetUser(uid, channelID, rc.RemoteId) - if err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error getting user for lastSyncAt update", - mlog.String("remote", rc.DisplayName), - mlog.String("user_id", uid), - mlog.Err(err), - ) - continue - } - - if err := scs.server.GetStore().SharedChannel().UpdateUserLastSyncAt(scu.Id, lastSyncAt); err != nil { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "error updating lastSyncAt for user", - mlog.String("remote", rc.DisplayName), - mlog.String("user_id", uid), - mlog.String("channel_id", channelID), - mlog.Err(err), - ) - } else { - scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "updated lastSyncAt for user", - mlog.String("remote", rc.DisplayName), - mlog.String("user_id", scu.UserId), - mlog.String("channel_id", channelID), - mlog.Int64("last_update_at", lastSyncAt), - ) - } - } -} - func (scs *Service) getUserTranslations(userId string) i18n.TranslateFunc { var locale string user, err := scs.server.GetStore().User().Get(context.Background(), userId) @@ -512,3 +365,72 @@ func (scs *Service) getUserTranslations(userId string) i18n.TranslateFunc { } return i18n.GetUserTranslations(locale) } + +// shouldUserSync determines if a user needs to be synchronized. +// User should be synchronized if it has no entry in the SharedChannelUsers table for the specified channel, +// or there is an entry but the LastSyncAt is less than user.UpdateAt +func (scs *Service) shouldUserSync(user *model.User, channelID string, rc *model.RemoteCluster) (sync bool, syncImage bool, err error) { + // don't sync users with the remote they originated from. + if user.RemoteId != nil && *user.RemoteId == rc.RemoteId { + return false, false, nil + } + + scu, err := scs.server.GetStore().SharedChannel().GetSingleUser(user.Id, channelID, rc.RemoteId) + if err != nil { + if _, ok := err.(errNotFound); !ok { + return false, false, err + } + + // user not in the SharedChannelUsers table, so we must add them. + scu = &model.SharedChannelUser{ + UserId: user.Id, + RemoteId: rc.RemoteId, + ChannelId: channelID, + } + if _, err = scs.server.GetStore().SharedChannel().SaveUser(scu); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error adding user to shared channel users", + mlog.String("remote_id", rc.RemoteId), + mlog.String("user_id", user.Id), + mlog.String("channel_id", user.Id), + mlog.Err(err), + ) + } + return true, true, nil + } + + return user.UpdateAt > scu.LastSyncAt, user.LastPictureUpdate > scu.LastSyncAt, nil +} + +func (scs *Service) syncProfileImage(user *model.User, channelID string, rc *model.RemoteCluster) { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), ProfileImageSyncTimeout) + defer cancel() + + rcs.SendProfileImage(ctx, user.Id, rc, scs.app, func(userId string, rc *model.RemoteCluster, resp *remotecluster.Response, err error) { + if resp.IsSuccess() { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Users profile image synchronized", + mlog.String("remote_id", rc.RemoteId), + mlog.String("user_id", user.Id), + ) + + if err2 := scs.server.GetStore().SharedChannel().UpdateUserLastSyncAt(user.Id, channelID, rc.RemoteId); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error updating users LastSyncTime after profile image update", + mlog.String("remote_id", rc.RemoteId), + mlog.String("user_id", user.Id), + mlog.Err(err2), + ) + } + return + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Error synchronizing users profile image", + mlog.String("remote_id", rc.RemoteId), + mlog.String("user_id", user.Id), + mlog.Err(err), + ) + }) +} diff --git a/services/sharedchannel/sync_send_remote.go b/services/sharedchannel/sync_send_remote.go new file mode 100644 index 0000000000..7cabd69530 --- /dev/null +++ b/services/sharedchannel/sync_send_remote.go @@ -0,0 +1,533 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sharedchannel + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/wiggin77/merror" + + "github.com/mattermost/mattermost-server/v5/model" + "github.com/mattermost/mattermost-server/v5/services/remotecluster" + "github.com/mattermost/mattermost-server/v5/shared/mlog" +) + +type sendSyncMsgResultFunc func(syncResp SyncResponse, err error) + +type attachment struct { + fi *model.FileInfo + post *model.Post +} + +type syncData struct { + task syncTask + rc *model.RemoteCluster + scr *model.SharedChannelRemote + + users map[string]*model.User + profileImages map[string]*model.User + posts []*model.Post + reactions []*model.Reaction + attachments []attachment + + resultRepeat bool + resultNextCursor model.GetPostsSinceForSyncCursor +} + +func newSyncData(task syncTask, rc *model.RemoteCluster, scr *model.SharedChannelRemote) *syncData { + return &syncData{ + task: task, + rc: rc, + scr: scr, + users: make(map[string]*model.User), + profileImages: make(map[string]*model.User), + resultNextCursor: model.GetPostsSinceForSyncCursor{LastPostUpdateAt: scr.LastPostUpdateAt, LastPostId: scr.LastPostId}, + } +} + +func (sd *syncData) isEmpty() bool { + return len(sd.users) == 0 && len(sd.profileImages) == 0 && len(sd.posts) == 0 && len(sd.reactions) == 0 && len(sd.attachments) == 0 +} + +func (sd *syncData) isCursorChanged() bool { + return sd.scr.LastPostUpdateAt != sd.resultNextCursor.LastPostUpdateAt || sd.scr.LastPostId != sd.resultNextCursor.LastPostId +} + +// syncForRemote updates a remote cluster with any new posts/reactions for a specific +// channel. If many changes are found, only the oldest X changes are sent and the channel +// is re-added to the task map. This ensures no channels are starved for updates even if some +// channels are very active. +// Returning an error forces a retry on the task. +func (scs *Service) syncForRemote(task syncTask, rc *model.RemoteCluster) error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return fmt.Errorf("cannot update remote cluster %s for channel id %s; Remote Cluster Service not enabled", rc.Name, task.channelID) + } + + scr, err := scs.server.GetStore().SharedChannel().GetRemoteByIds(task.channelID, rc.RemoteId) + if err != nil { + return err + } + + // if this is retrying a failed msg, just send it again. + if task.retryMsg != nil { + sd := newSyncData(task, rc, scr) + sd.users = task.retryMsg.Users + sd.posts = task.retryMsg.Posts + sd.reactions = task.retryMsg.Reactions + return scs.sendSyncData(sd) + } + + sd := newSyncData(task, rc, scr) + + // schedule another sync if the repeat flag is set at some point. + defer func(rpt *bool) { + if *rpt { + scs.addTask(newSyncTask(task.channelID, task.remoteID, nil)) + } + }(&sd.resultRepeat) + + // fetch new posts or retry post. + if err := scs.fetchPostsForSync(sd); err != nil { + return fmt.Errorf("cannot fetch posts for sync %v: %w", sd, err) + } + + if !rc.IsOnline() { + if len(sd.posts) != 0 { + scs.notifyRemoteOffline(sd.posts, rc) + } + sd.resultRepeat = false + return nil + } + + // fetch users that have updated their user profile or image. + if err := scs.fetchUsersForSync(sd); err != nil { + return fmt.Errorf("cannot fetch users for sync %v: %w", sd, err) + } + + // fetch reactions for posts + if err := scs.fetchReactionsForSync(sd); err != nil { + return fmt.Errorf("cannot fetch reactions for sync %v: %w", sd, err) + } + + // fetch users associated with posts & reactions + if err := scs.fetchPostUsersForSync(sd); err != nil { + return fmt.Errorf("cannot fetch post users for sync %v: %w", sd, err) + } + + // filter out any posts that don't need to be sent. + scs.filterPostsForSync(sd) + + // fetch attachments for posts + if err := scs.fetchPostAttachmentsForSync(sd); err != nil { + return fmt.Errorf("cannot fetch post attachments for sync %v: %w", sd, err) + } + + if sd.isEmpty() { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Not sending sync data; everything filtered out", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", task.channelID), + mlog.Bool("repeat", sd.resultRepeat), + ) + if sd.isCursorChanged() { + scs.updateCursorForRemote(sd.scr.Id, sd.rc, sd.resultNextCursor) + } + return nil + } + + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceDebug, "Sending sync data", + mlog.String("remote", rc.DisplayName), + mlog.String("channel_id", task.channelID), + mlog.Bool("repeat", sd.resultRepeat), + mlog.Int("users", len(sd.users)), + mlog.Int("images", len(sd.profileImages)), + mlog.Int("posts", len(sd.posts)), + mlog.Int("reactions", len(sd.reactions)), + mlog.Int("attachments", len(sd.attachments)), + ) + + return scs.sendSyncData(sd) +} + +// fetchUsersForSync populates the sync data with any channel users who updated their user profile +// since the last sync. +func (scs *Service) fetchUsersForSync(sd *syncData) error { + filter := model.GetUsersForSyncFilter{ + ChannelID: sd.task.channelID, + Limit: MaxUsersPerSync, + } + users, err := scs.server.GetStore().SharedChannel().GetUsersForSync(filter) + if err != nil { + return err + } + + for _, u := range users { + if u.GetRemoteID() != sd.rc.RemoteId { + sd.users[u.Id] = u + } + } + + filter.CheckProfileImage = true + usersImage, err := scs.server.GetStore().SharedChannel().GetUsersForSync(filter) + if err != nil { + return err + } + + for _, u := range usersImage { + if u.GetRemoteID() != sd.rc.RemoteId { + sd.profileImages[u.Id] = u + } + } + return nil +} + +// fetchPostsForSync populates the sync data with any new posts since the last sync. +func (scs *Service) fetchPostsForSync(sd *syncData) error { + options := model.GetPostsSinceForSyncOptions{ + ChannelId: sd.task.channelID, + IncludeDeleted: true, + } + cursor := model.GetPostsSinceForSyncCursor{ + LastPostUpdateAt: sd.scr.LastPostUpdateAt, + LastPostId: sd.scr.LastPostId, + } + + posts, nextCursor, err := scs.server.GetStore().Post().GetPostsSinceForSync(options, cursor, MaxPostsPerSync) + if err != nil { + return fmt.Errorf("could not fetch new posts for sync: %w", err) + } + + // Append the posts individually, checking for root posts that might appear later in the list. + // This is due to the UpdateAt collision handling algorithm where the order of posts is not based + // on UpdateAt or CreateAt when the posts have the same UpdateAt value. Here we are guarding + // against a root post with the same UpdateAt (and probably the same CreateAt) appearing later + // in the list and must be sync'd before the child post. This is and edge case that likely only + // happens during load testing or bulk imports. + for _, p := range posts { + if p.RootId != "" { + root, err := scs.server.GetStore().Post().GetSingle(p.RootId, true) + if err == nil { + if (root.CreateAt >= cursor.LastPostUpdateAt || root.UpdateAt >= cursor.LastPostUpdateAt) && !containsPost(sd.posts, root) { + sd.posts = append(sd.posts, root) + } + } + } + sd.posts = append(sd.posts, p) + } + + sd.resultNextCursor = nextCursor + sd.resultRepeat = len(posts) == MaxPostsPerSync + return nil +} + +func containsPost(posts []*model.Post, post *model.Post) bool { + for _, p := range posts { + if p.Id == post.Id { + return true + } + } + return false +} + +// fetchReactionsForSync populates the sync data with any new reactions since the last sync. +func (scs *Service) fetchReactionsForSync(sd *syncData) error { + merr := merror.New() + for _, post := range sd.posts { + // any reactions originating from the remote cluster are filtered out + reactions, err := scs.server.GetStore().Reaction().GetForPostSince(post.Id, sd.scr.LastPostUpdateAt, sd.rc.RemoteId, true) + if err != nil { + merr.Append(fmt.Errorf("could not get reactions for post %s: %w", post.Id, err)) + continue + } + sd.reactions = append(sd.reactions, reactions...) + } + return merr.ErrorOrNil() +} + +// fetchPostUsersForSync populates the sync data with all users associated with posts. +func (scs *Service) fetchPostUsersForSync(sd *syncData) error { + sc, err := scs.server.GetStore().SharedChannel().Get(sd.task.channelID) + if err != nil { + return fmt.Errorf("cannot determine teamID: %w", err) + } + + type p2mm struct { + post *model.Post + mentionMap model.UserMentionMap + } + + userIDs := make(map[string]p2mm) + + for _, reaction := range sd.reactions { + userIDs[reaction.UserId] = p2mm{} + } + + for _, post := range sd.posts { + // add author + userIDs[post.UserId] = p2mm{} + + // get mentions and users for each mention + mentionMap := scs.app.MentionsToTeamMembers(post.Message, sc.TeamId) + for _, userID := range mentionMap { + userIDs[userID] = p2mm{ + post: post, + mentionMap: mentionMap, + } + } + } + + merr := merror.New() + + for userID, v := range userIDs { + user, err := scs.server.GetStore().User().Get(context.Background(), userID) + if err != nil { + merr.Append(fmt.Errorf("could not get user %s: %w", userID, err)) + continue + } + + sync, syncImage, err2 := scs.shouldUserSync(user, sd.task.channelID, sd.rc) + if err2 != nil { + merr.Append(fmt.Errorf("could not check should sync user %s: %w", userID, err)) + continue + } + + if sync { + sd.users[user.Id] = sanitizeUserForSync(user) + } + + if syncImage { + sd.profileImages[user.Id] = sanitizeUserForSync(user) + } + + // if this was a mention then put the real username in place of the username+remotename, but only + // when sending to the remote that the user belongs to. + if v.post != nil && user.RemoteId != nil && *user.RemoteId == sd.rc.RemoteId { + fixMention(v.post, v.mentionMap, user) + } + } + return merr.ErrorOrNil() +} + +// fetchPostAttachmentsForSync populates the sync data with any file attachments for new posts. +func (scs *Service) fetchPostAttachmentsForSync(sd *syncData) error { + merr := merror.New() + for _, post := range sd.posts { + fis, err := scs.server.GetStore().FileInfo().GetForPost(post.Id, false, true, true) + if err != nil { + merr.Append(fmt.Errorf("could not get file attachment info for post %s: %w", post.Id, err)) + continue + } + + for _, fi := range fis { + if scs.shouldSyncAttachment(fi, sd.rc) { + sd.attachments = append(sd.attachments, attachment{fi: fi, post: post}) + } + } + } + return merr.ErrorOrNil() +} + +// filterPostsforSync removes any posts that do not need to sync. +func (scs *Service) filterPostsForSync(sd *syncData) { + filtered := make([]*model.Post, 0, len(sd.posts)) + + for _, p := range sd.posts { + // Don't resend an existing post where only the reactions changed. + // Posts we must send: + // - new posts (EditAt == 0) + // - edited posts (EditAt >= LastPostUpdateAt) + // - deleted posts (DeleteAt > 0) + if p.EditAt > 0 && p.EditAt < sd.scr.LastPostUpdateAt && p.DeleteAt == 0 { + continue + } + + // Don't send a deleted post if it is just the original copy from an edit. + if p.DeleteAt > 0 && p.OriginalId != "" { + continue + } + + // don't sync a post back to the remote it came from. + if p.GetRemoteID() == sd.rc.RemoteId { + continue + } + + // parse out all permalinks in the message. + p.Message = scs.processPermalinkToRemote(p) + + filtered = append(filtered, p) + } + sd.posts = filtered +} + +// sendSyncData sends all the collected users, posts, reactions, images, and attachments to the +// remote cluster. +// The order of items sent is important: users -> attachments -> posts -> reactions -> profile images +func (scs *Service) sendSyncData(sd *syncData) error { + merr := merror.New() + + // send users + if len(sd.users) != 0 { + if err := scs.sendUserSyncData(sd); err != nil { + merr.Append(fmt.Errorf("cannot send user sync data: %w", err)) + } + } + + // send attachments + if len(sd.attachments) != 0 { + scs.sendAttachmentSyncData(sd) + } + + // send posts + if len(sd.posts) != 0 { + if err := scs.sendPostSyncData(sd); err != nil { + merr.Append(fmt.Errorf("cannot send post sync data: %w", err)) + } + } else if sd.isCursorChanged() { + scs.updateCursorForRemote(sd.scr.Id, sd.rc, sd.resultNextCursor) + } + + // send reactions + if len(sd.reactions) != 0 { + if err := scs.sendReactionSyncData(sd); err != nil { + merr.Append(fmt.Errorf("cannot send reaction sync data: %w", err)) + } + } + + // send user profile images + if len(sd.profileImages) != 0 { + scs.sendProfileImageSyncData(sd) + } + + return merr.ErrorOrNil() +} + +// sendUserSyncData sends the collected user updates to the remote cluster. +func (scs *Service) sendUserSyncData(sd *syncData) error { + msg := newSyncMsg(sd.task.channelID) + msg.Users = sd.users + + err := scs.sendSyncMsgToRemote(msg, sd.rc, func(syncResp SyncResponse, errResp error) { + for _, userID := range syncResp.UsersSyncd { + if err := scs.server.GetStore().SharedChannel().UpdateUserLastSyncAt(userID, sd.task.channelID, sd.rc.RemoteId); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Cannot update shared channel user LastSyncAt", + mlog.String("user_id", userID), + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Err(err), + ) + } + } + if len(syncResp.UserErrors) != 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Response indicates error for user(s) sync", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Any("users", syncResp.UserErrors), + ) + } + }) + return err +} + +// sendAttachmentSyncData sends the collected post updates to the remote cluster. +func (scs *Service) sendAttachmentSyncData(sd *syncData) { + for _, a := range sd.attachments { + if err := scs.sendAttachmentForRemote(a.fi, a.post, sd.rc); err != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Cannot sync post attachment", + mlog.String("post_id", a.post.Id), + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Err(err), + ) + } + // updating SharedChannelAttachments with LastSyncAt is already done. + } +} + +// sendPostSyncData sends the collected post updates to the remote cluster. +func (scs *Service) sendPostSyncData(sd *syncData) error { + msg := newSyncMsg(sd.task.channelID) + msg.Posts = sd.posts + + return scs.sendSyncMsgToRemote(msg, sd.rc, func(syncResp SyncResponse, errResp error) { + if len(syncResp.PostErrors) != 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Response indicates error for post(s) sync", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Any("posts", syncResp.PostErrors), + ) + + for _, postID := range syncResp.PostErrors { + scs.handlePostError(postID, sd.task, sd.rc) + } + } + scs.updateCursorForRemote(sd.scr.Id, sd.rc, sd.resultNextCursor) + }) +} + +// sendReactionSyncData sends the collected reaction updates to the remote cluster. +func (scs *Service) sendReactionSyncData(sd *syncData) error { + msg := newSyncMsg(sd.task.channelID) + msg.Reactions = sd.reactions + + return scs.sendSyncMsgToRemote(msg, sd.rc, func(syncResp SyncResponse, errResp error) { + if len(syncResp.ReactionErrors) != 0 { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Response indicates error for reactions(s) sync", + mlog.String("channel_id", sd.task.channelID), + mlog.String("remote_id", sd.rc.RemoteId), + mlog.Any("reaction_posts", syncResp.ReactionErrors), + ) + } + }) +} + +// sendProfileImageSyncData sends the collected user profile image updates to the remote cluster. +func (scs *Service) sendProfileImageSyncData(sd *syncData) { + for _, user := range sd.profileImages { + scs.syncProfileImage(user, sd.task.channelID, sd.rc) + } +} + +// sendSyncMsgToRemote synchronously sends the sync message to the remote cluster. +func (scs *Service) sendSyncMsgToRemote(msg *syncMsg, rc *model.RemoteCluster, f sendSyncMsgResultFunc) error { + rcs := scs.server.GetRemoteClusterService() + if rcs == nil { + return fmt.Errorf("cannot update remote cluster %s for channel id %s; Remote Cluster Service not enabled", rc.Name, msg.ChannelId) + } + + b, err := json.Marshal(msg) + if err != nil { + return err + } + rcMsg := model.NewRemoteClusterMsg(TopicSync, b) + + ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + + err = rcs.SendMsg(ctx, rcMsg, rc, func(rcMsg model.RemoteClusterMsg, rc *model.RemoteCluster, rcResp *remotecluster.Response, errResp error) { + defer wg.Done() + + var syncResp SyncResponse + if err2 := json.Unmarshal(rcResp.Payload, &syncResp); err2 != nil { + scs.server.GetLogger().Log(mlog.LvlSharedChannelServiceError, "Invalid sync msg response from remote cluster", + mlog.String("remote", rc.Name), + mlog.String("channel_id", msg.ChannelId), + mlog.Err(err2), + ) + return + } + + if f != nil { + f(syncResp, errResp) + } + }) + + wg.Wait() + return err +} diff --git a/services/sharedchannel/util.go b/services/sharedchannel/util.go index 93d019d900..77a0d4eff9 100644 --- a/services/sharedchannel/util.go +++ b/services/sharedchannel/util.go @@ -10,6 +10,41 @@ import ( "github.com/mattermost/mattermost-server/v5/model" ) +// fixMention replaces any mentions in a post for the user with the user's real username. +func fixMention(post *model.Post, mentionMap model.UserMentionMap, user *model.User) { + if post == nil || len(mentionMap) == 0 { + return + } + + realUsername, ok := user.GetProp(KeyRemoteUsername) + if !ok { + return + } + + // there may be more than one mention for each user so we have to walk the whole map. + for mention, id := range mentionMap { + if id == user.Id && strings.Contains(mention, ":") { + post.Message = strings.ReplaceAll(post.Message, "@"+mention, "@"+realUsername) + } + } +} + +func sanitizeUserForSync(user *model.User) *model.User { + user.Password = model.NewId() + user.AuthData = nil + user.AuthService = "" + user.Roles = "system_user" + user.AllowMarketing = false + user.NotifyProps = model.StringMap{} + user.LastPasswordUpdate = 0 + user.LastPictureUpdate = 0 + user.FailedAttempts = 0 + user.MfaActive = false + user.MfaSecret = "" + + return user +} + // mungUsername creates a new username by combining username and remote cluster name, plus // a suffix to create uniqueness. If the resulting username exceeds the max length then // it is truncated and ellipses added. diff --git a/store/opentracinglayer/opentracinglayer.go b/store/opentracinglayer/opentracinglayer.go index 749ede2c0b..82acf4b109 100644 --- a/store/opentracinglayer/opentracinglayer.go +++ b/store/opentracinglayer/opentracinglayer.go @@ -5394,7 +5394,7 @@ func (s *OpenTracingLayerPostStore) GetPostsSince(options model.GetPostsSinceOpt return result, err } -func (s *OpenTracingLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { +func (s *OpenTracingLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.GetPostsSinceForSync") s.Root.Store.SetContext(newCtx) @@ -5403,13 +5403,13 @@ func (s *OpenTracingLayerPostStore) GetPostsSinceForSync(options model.GetPostsS }() defer span.Finish() - result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + result, resultVar1, err := s.PostStore.GetPostsSinceForSync(options, cursor, limit) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) } - return result, err + return result, resultVar1, err } func (s *OpenTracingLayerPostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { @@ -7274,16 +7274,52 @@ func (s *OpenTracingLayerSharedChannelStore) GetRemotesStatus(channelId string) return result, err } -func (s *OpenTracingLayerSharedChannelStore) GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { +func (s *OpenTracingLayerSharedChannelStore) GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { origCtx := s.Root.Store.Context() - span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetUser") + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetSingleUser") s.Root.Store.SetContext(newCtx) defer func() { s.Root.Store.SetContext(origCtx) }() defer span.Finish() - result, err := s.SharedChannelStore.GetUser(userID, channelID, remoteID) + result, err := s.SharedChannelStore.GetSingleUser(userID, channelID, remoteID) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetUsersForSync") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetUsersForSync(filter) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + +func (s *OpenTracingLayerSharedChannelStore) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.GetUsersForUser") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SharedChannelStore.GetUsersForUser(userID) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -7454,16 +7490,16 @@ func (s *OpenTracingLayerSharedChannelStore) UpdateRemote(remote *model.SharedCh return result, err } -func (s *OpenTracingLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { +func (s *OpenTracingLayerSharedChannelStore) UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { origCtx := s.Root.Store.Context() - span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateRemoteNextSyncAt") + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateRemoteCursor") s.Root.Store.SetContext(newCtx) defer func() { s.Root.Store.SetContext(origCtx) }() defer span.Finish() - err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateRemoteCursor(id, cursor) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -7472,7 +7508,7 @@ func (s *OpenTracingLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, s return err } -func (s *OpenTracingLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { +func (s *OpenTracingLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SharedChannelStore.UpdateUserLastSyncAt") s.Root.Store.SetContext(newCtx) @@ -7481,7 +7517,7 @@ func (s *OpenTracingLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syn }() defer span.Finish() - err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateUserLastSyncAt(userID, channelID, remoteID) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) diff --git a/store/retrylayer/retrylayer.go b/store/retrylayer/retrylayer.go index 7a8d5e8d0f..efe08cd132 100644 --- a/store/retrylayer/retrylayer.go +++ b/store/retrylayer/retrylayer.go @@ -5824,21 +5824,21 @@ func (s *RetryLayerPostStore) GetPostsSince(options model.GetPostsSinceOptions, } -func (s *RetryLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { +func (s *RetryLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) { tries := 0 for { - result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + result, resultVar1, err := s.PostStore.GetPostsSinceForSync(options, cursor, limit) if err == nil { - return result, nil + return result, resultVar1, nil } if !isRepeatableError(err) { - return result, err + return result, resultVar1, err } tries++ if tries >= 3 { err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return result, err + return result, resultVar1, err } } @@ -7896,11 +7896,51 @@ func (s *RetryLayerSharedChannelStore) GetRemotesStatus(channelId string) ([]*mo } -func (s *RetryLayerSharedChannelStore) GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { +func (s *RetryLayerSharedChannelStore) GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { tries := 0 for { - result, err := s.SharedChannelStore.GetUser(userID, channelID, remoteID) + result, err := s.SharedChannelStore.GetSingleUser(userID, channelID, remoteID) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetUsersForSync(filter) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + } + +} + +func (s *RetryLayerSharedChannelStore) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) { + + tries := 0 + for { + result, err := s.SharedChannelStore.GetUsersForUser(userID) if err == nil { return result, nil } @@ -8096,11 +8136,11 @@ func (s *RetryLayerSharedChannelStore) UpdateRemote(remote *model.SharedChannelR } -func (s *RetryLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { +func (s *RetryLayerSharedChannelStore) UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { tries := 0 for { - err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateRemoteCursor(id, cursor) if err == nil { return nil } @@ -8116,11 +8156,11 @@ func (s *RetryLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTim } -func (s *RetryLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { +func (s *RetryLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { tries := 0 for { - err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateUserLastSyncAt(userID, channelID, remoteID) if err == nil { return nil } diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index 9fe2b07729..622d913f03 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -1004,26 +1004,16 @@ func (s *SqlPostStore) HasAutoResponsePostByUserSince(options model.GetPostsSinc return exist > 0, nil } -func (s *SqlPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, _ /* allowFromCache */ bool) ([]*model.Post, error) { - if options.Limit < 0 || options.Limit > 1000 { - return nil, store.NewErrInvalidInput("Post", "", options.Limit) - } - - order := " ASC" - if options.SortDescending { - order = " DESC" - } - +func (s *SqlPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) { query := s.getQueryBuilder(). Select("*"). From("Posts"). - Where(sq.GtOrEq{"UpdateAt": options.Since}). - Where(sq.Eq{"ChannelId": options.ChannelId}). - Limit(uint64(options.Limit)). - OrderBy("CreateAt"+order, "DeleteAt", "Id") + Where(sq.Or{sq.Gt{"UpdateAt": cursor.LastPostUpdateAt}, sq.And{sq.Eq{"UpdateAt": cursor.LastPostUpdateAt}, sq.Gt{"Id": cursor.LastPostId}}}). + OrderBy("UpdateAt", "Id"). + Limit(uint64(limit)) - if options.Until > 0 { - query = query.Where(sq.LtOrEq{"UpdateAt": options.Until}) + if options.ChannelId != "" { + query = query.Where(sq.Eq{"ChannelId": options.ChannelId}) } if !options.IncludeDeleted { @@ -1034,24 +1024,22 @@ func (s *SqlPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOp query = query.Where(sq.NotEq{"COALESCE(Posts.RemoteId,'')": options.ExcludeRemoteId}) } - if options.Offset > 0 { - query = query.Offset(uint64(options.Offset)) - } - queryString, args, err := query.ToSql() if err != nil { - return nil, errors.Wrap(err, "getpostssinceforsync_tosql") + return nil, cursor, errors.Wrap(err, "getpostssinceforsync_tosql") } var posts []*model.Post - _, err = s.GetReplica().Select(&posts, queryString, args...) - if err != nil { - return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) + return nil, cursor, errors.Wrapf(err, "error getting Posts with channelId=%s", options.ChannelId) } - return posts, nil + if len(posts) != 0 { + cursor.LastPostUpdateAt = posts[len(posts)-1].UpdateAt + cursor.LastPostId = posts[len(posts)-1].Id + } + return posts, cursor, nil } func (s *SqlPostStore) GetPostsBefore(options model.GetPostsOptions) (*model.PostList, error) { diff --git a/store/sqlstore/shared_channel_store.go b/store/sqlstore/shared_channel_store.go index e5adfe6286..3ab4b09167 100644 --- a/store/sqlstore/shared_channel_store.go +++ b/store/sqlstore/shared_channel_store.go @@ -14,6 +14,10 @@ import ( "github.com/pkg/errors" ) +const ( + DefaultGetUsersForSyncLimit = 100 +) + type SqlSharedChannelStore struct { *SqlStore } @@ -40,6 +44,7 @@ func newSqlSharedChannelStore(sqlStore *SqlStore) store.SharedChannelStore { tableSharedChannelRemotes.ColMap("ChannelId").SetMaxSize(26) tableSharedChannelRemotes.ColMap("CreatorId").SetMaxSize(26) tableSharedChannelRemotes.ColMap("RemoteId").SetMaxSize(26) + tableSharedChannelRemotes.ColMap("LastPostId").SetMaxSize(26) tableSharedChannelRemotes.SetUniqueTogether("ChannelId", "RemoteId") tableSharedChannelUsers := db.AddTableWithName(model.SharedChannelUser{}, "SharedChannelUsers").SetKeys(false, "Id") @@ -467,20 +472,21 @@ func (s SqlSharedChannelStore) GetRemoteForUser(remoteId string, userId string) return &rc, nil } -// UpdateRemoteNextSyncAt updates the NextSyncAt timestamp for the specified SharedChannelRemote. -func (s SqlSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { +// UpdateRemoteCursor updates the LastPostUpdateAt timestamp and LastPostId for the specified SharedChannelRemote. +func (s SqlSharedChannelStore) UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { squery, args, err := s.getQueryBuilder(). Update("SharedChannelRemotes"). - Set("NextSyncAt", syncTime). + Set("LastPostUpdateAt", cursor.LastPostUpdateAt). + Set("LastPostId", cursor.LastPostId). Where(sq.Eq{"Id": id}). ToSql() if err != nil { - return errors.Wrap(err, "update_shared_channel_remote_next_sync_at_tosql") + return errors.Wrap(err, "update_shared_channel_remote_cursor_tosql") } result, err := s.GetMaster().Exec(squery, args...) if err != nil { - return errors.Wrap(err, "failed to update NextSyncAt for SharedChannelRemote") + return errors.Wrap(err, "failed to update cursor for SharedChannelRemote") } count, err := result.RowsAffected() @@ -556,8 +562,8 @@ func (s SqlSharedChannelStore) SaveUser(scUser *model.SharedChannelUser) (*model return scUser, nil } -// GetUser fetches a shared channel user based on user_id and remoteId. -func (s SqlSharedChannelStore) GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { +// GetSingleUser fetches a shared channel user based on userID, channelID and remoteID. +func (s SqlSharedChannelStore) GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { var scu model.SharedChannelUser squery, args, err := s.getQueryBuilder(). @@ -569,7 +575,7 @@ func (s SqlSharedChannelStore) GetUser(userID string, channelID string, remoteID ToSql() if err != nil { - return nil, errors.Wrapf(err, "getsharedchanneluser_tosql") + return nil, errors.Wrapf(err, "getsharedchannelsingleuser_tosql") } if err := s.GetReplica().SelectOne(&scu, squery, args...); err != nil { @@ -581,20 +587,104 @@ func (s SqlSharedChannelStore) GetUser(userID string, channelID string, remoteID return &scu, nil } -// UpdateUserLastSyncAt updates the LastSyncAt timestamp for the specified SharedChannelUser. -func (s SqlSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { +// GetUsersForUser fetches all shared channel user records based on userID. +func (s SqlSharedChannelStore) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) { squery, args, err := s.getQueryBuilder(). - Update("SharedChannelUsers"). - Set("LastSyncAt", syncTime). - Where(sq.Eq{"Id": id}). + Select("*"). + From("SharedChannelUsers"). + Where(sq.Eq{"SharedChannelUsers.UserId": userID}). ToSql() + if err != nil { - return errors.Wrap(err, "update_shared_channel_user_last_sync_at_tosql") + return nil, errors.Wrapf(err, "getsharedchanneluser_tosql") } - result, err := s.GetMaster().Exec(squery, args...) + var users []*model.SharedChannelUser + if _, err := s.GetReplica().Select(&users, squery, args...); err != nil { + if err == sql.ErrNoRows { + return make([]*model.SharedChannelUser, 0), nil + } + return nil, errors.Wrapf(err, "failed to find shared channel user with UserId=%s", userID) + } + return users, nil +} + +// GetUsersForSync fetches all shared channel users that need to be synchronized, meaning their +// `SharedChannelUsers.LastSyncAt` is less than or equal to `User.UpdateAt`. +func (s SqlSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { + if filter.Limit <= 0 { + filter.Limit = DefaultGetUsersForSyncLimit + } + + query := s.getQueryBuilder(). + Select("u.*"). + Distinct(). + From("Users AS u"). + Join("SharedChannelUsers AS scu ON u.Id = scu.UserId"). + OrderBy("u.Id"). + Limit(filter.Limit) + + if filter.CheckProfileImage { + query = query.Where("scu.LastSyncAt < u.LastPictureUpdate") + } else { + query = query.Where("scu.LastSyncAt < u.UpdateAt") + } + + if filter.ChannelID != "" { + query = query.Where(sq.Eq{"scu.ChannelId": filter.ChannelID}) + } + + sqlQuery, args, err := query.ToSql() if err != nil { - return errors.Wrap(err, "failed to update LastSycnAt for SharedChannelUser") + return nil, errors.Wrapf(err, "getsharedchannelusersforsync_tosql") + } + + var users []*model.User + if _, err := s.GetReplica().Select(&users, sqlQuery, args...); err != nil { + if err == sql.ErrNoRows { + return make([]*model.User, 0), nil + } + return nil, errors.Wrapf(err, "failed to fetch shared channel users with ChannelId=%s", + filter.ChannelID) + } + return users, nil +} + +// UpdateUserLastSyncAt updates the LastSyncAt timestamp for the specified SharedChannelUser. +func (s SqlSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { + args := map[string]interface{}{"UserId": userID, "ChannelId": channelID, "RemoteId": remoteID} + + var query string + if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { + query = ` + UPDATE + SharedChannelUsers AS scu + SET + LastSyncAt = GREATEST(Users.UpdateAt, Users.LastPictureUpdate) + FROM + Users + WHERE + Users.Id = scu.UserId AND scu.UserId = :UserId AND scu.ChannelId = :ChannelId AND scu.RemoteId = :RemoteId + ` + } else if s.DriverName() == model.DATABASE_DRIVER_MYSQL { + query = ` + UPDATE + SharedChannelUsers AS scu + INNER JOIN + Users ON scu.UserId = Users.Id + SET + LastSyncAt = GREATEST(Users.UpdateAt, Users.LastPictureUpdate) + WHERE + scu.UserId = :UserId AND scu.ChannelId = :ChannelId AND scu.RemoteId = :RemoteId + ` + } else { + return errors.New("unsupported DB driver " + s.DriverName()) + } + + result, err := s.GetMaster().Exec(query, args) + if err != nil { + return fmt.Errorf("failed to update LastSyncAt for SharedChannelUser with userId=%s, channelId=%s, remoteId=%s: %w", + userID, channelID, remoteID, err) } count, err := result.RowsAffected() @@ -602,7 +692,7 @@ func (s SqlSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) e return errors.Wrap(err, "failed to determine rows affected") } if count == 0 { - return fmt.Errorf("id not found: %s", id) + return fmt.Errorf("SharedChannelUser not found: userId=%s, channelId=%s, remoteId=%s", userID, channelID, remoteID) } return nil } diff --git a/store/sqlstore/upgrade.go b/store/sqlstore/upgrade.go index 0c0eebc774..06e080df61 100644 --- a/store/sqlstore/upgrade.go +++ b/store/sqlstore/upgrade.go @@ -1067,6 +1067,8 @@ func upgradeDatabaseToVersion536(sqlStore *SqlStore) { //if shouldPerformUpgrade(sqlStore, Version5350, Version5360) { sqlStore.CreateColumnIfNotExists("SharedChannelUsers", "ChannelId", "VARCHAR(26)", "VARCHAR(26)", "") + sqlStore.CreateColumnIfNotExists("SharedChannelRemotes", "LastPostUpdateAt", "bigint", "bigint", "0") + sqlStore.CreateColumnIfNotExists("SharedChannelRemotes", "LastPostId", "VARCHAR(26)", "VARCHAR(26)", "") // timed dnd status support sqlStore.CreateColumnIfNotExistsNoDefault("Status", "DNDEndTime", "BIGINT", "BIGINT") diff --git a/store/store.go b/store/store.go index fcd0fa0ec5..1c6df82831 100644 --- a/store/store.go +++ b/store/store.go @@ -341,7 +341,7 @@ type PostStore interface { SearchPostsInTeamForUser(paramsList []*model.SearchParams, userID, teamID string, page, perPage int) (*model.PostSearchResults, error) GetOldestEntityCreationTime() (int64, error) HasAutoResponsePostByUserSince(options model.GetPostsSinceOptions, userId string) (bool, error) - GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) + GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) } type UserStore interface { @@ -853,13 +853,15 @@ type SharedChannelStore interface { GetRemoteForUser(remoteId string, userId string) (*model.RemoteCluster, error) GetRemoteByIds(channelId string, remoteId string) (*model.SharedChannelRemote, error) GetRemotes(opts model.SharedChannelRemoteFilterOpts) ([]*model.SharedChannelRemote, error) - UpdateRemoteNextSyncAt(id string, syncTime int64) error + UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error DeleteRemote(remoteId string) (bool, error) GetRemotesStatus(channelId string) ([]*model.SharedChannelRemoteStatus, error) SaveUser(remote *model.SharedChannelUser) (*model.SharedChannelUser, error) - GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) - UpdateUserLastSyncAt(id string, syncTime int64) error + GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) + GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) + GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) + UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error SaveAttachment(remote *model.SharedChannelAttachment) (*model.SharedChannelAttachment, error) UpsertAttachment(remote *model.SharedChannelAttachment) (string, error) diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index f556fef8e4..0625d4b47b 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -538,27 +538,34 @@ func (_m *PostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFrom return r0, r1 } -// GetPostsSinceForSync provides a mock function with given fields: options, allowFromCache -func (_m *PostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { - ret := _m.Called(options, allowFromCache) +// GetPostsSinceForSync provides a mock function with given fields: options, cursor, limit +func (_m *PostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) { + ret := _m.Called(options, cursor, limit) var r0 []*model.Post - if rf, ok := ret.Get(0).(func(model.GetPostsSinceForSyncOptions, bool) []*model.Post); ok { - r0 = rf(options, allowFromCache) + if rf, ok := ret.Get(0).(func(model.GetPostsSinceForSyncOptions, model.GetPostsSinceForSyncCursor, int) []*model.Post); ok { + r0 = rf(options, cursor, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*model.Post) } } - var r1 error - if rf, ok := ret.Get(1).(func(model.GetPostsSinceForSyncOptions, bool) error); ok { - r1 = rf(options, allowFromCache) + var r1 model.GetPostsSinceForSyncCursor + if rf, ok := ret.Get(1).(func(model.GetPostsSinceForSyncOptions, model.GetPostsSinceForSyncCursor, int) model.GetPostsSinceForSyncCursor); ok { + r1 = rf(options, cursor, limit) } else { - r1 = ret.Error(1) + r1 = ret.Get(1).(model.GetPostsSinceForSyncCursor) } - return r0, r1 + var r2 error + if rf, ok := ret.Get(2).(func(model.GetPostsSinceForSyncOptions, model.GetPostsSinceForSyncCursor, int) error); ok { + r2 = rf(options, cursor, limit) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // GetRepliesForExport provides a mock function with given fields: parentID diff --git a/store/storetest/mocks/SharedChannelStore.go b/store/storetest/mocks/SharedChannelStore.go index 505c6190fd..0973ac908e 100644 --- a/store/storetest/mocks/SharedChannelStore.go +++ b/store/storetest/mocks/SharedChannelStore.go @@ -261,8 +261,8 @@ func (_m *SharedChannelStore) GetRemotesStatus(channelId string) ([]*model.Share return r0, r1 } -// GetUser provides a mock function with given fields: userID, channelID, remoteID -func (_m *SharedChannelStore) GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { +// GetSingleUser provides a mock function with given fields: userID, channelID, remoteID +func (_m *SharedChannelStore) GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { ret := _m.Called(userID, channelID, remoteID) var r0 *model.SharedChannelUser @@ -284,6 +284,52 @@ func (_m *SharedChannelStore) GetUser(userID string, channelID string, remoteID return r0, r1 } +// GetUsersForSync provides a mock function with given fields: filter +func (_m *SharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { + ret := _m.Called(filter) + + var r0 []*model.User + if rf, ok := ret.Get(0).(func(model.GetUsersForSyncFilter) []*model.User); ok { + r0 = rf(filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(model.GetUsersForSyncFilter) error); ok { + r1 = rf(filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetUsersForUser provides a mock function with given fields: userID +func (_m *SharedChannelStore) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) { + ret := _m.Called(userID) + + var r0 []*model.SharedChannelUser + if rf, ok := ret.Get(0).(func(string) []*model.SharedChannelUser); ok { + r0 = rf(userID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.SharedChannelUser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(userID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // HasChannel provides a mock function with given fields: channelID func (_m *SharedChannelStore) HasChannel(channelID string) (bool, error) { ret := _m.Called(channelID) @@ -478,13 +524,13 @@ func (_m *SharedChannelStore) UpdateRemote(remote *model.SharedChannelRemote) (* return r0, r1 } -// UpdateRemoteNextSyncAt provides a mock function with given fields: id, syncTime -func (_m *SharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { - ret := _m.Called(id, syncTime) +// UpdateRemoteCursor provides a mock function with given fields: id, cursor +func (_m *SharedChannelStore) UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { + ret := _m.Called(id, cursor) var r0 error - if rf, ok := ret.Get(0).(func(string, int64) error); ok { - r0 = rf(id, syncTime) + if rf, ok := ret.Get(0).(func(string, model.GetPostsSinceForSyncCursor) error); ok { + r0 = rf(id, cursor) } else { r0 = ret.Error(0) } @@ -492,13 +538,13 @@ func (_m *SharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) return r0 } -// UpdateUserLastSyncAt provides a mock function with given fields: id, syncTime -func (_m *SharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { - ret := _m.Called(id, syncTime) +// UpdateUserLastSyncAt provides a mock function with given fields: userID, channelID, remoteID +func (_m *SharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { + ret := _m.Called(userID, channelID, remoteID) var r0 error - if rf, ok := ret.Get(0).(func(string, int64) error); ok { - r0 = rf(id, syncTime) + if rf, ok := ret.Get(0).(func(string, string, string) error); ok { + r0 = rf(userID, channelID, remoteID) } else { r0 = ret.Error(0) } diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index d6ffb5b00e..ea547802e0 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -57,6 +57,7 @@ func TestPostStore(t *testing.T, ss store.Store, s SqlStore) { t.Run("GetDirectPostParentsForExportAfterBatched", func(t *testing.T) { testPostStoreGetDirectPostParentsForExportAfterBatched(t, ss, s) }) t.Run("GetForThread", func(t *testing.T) { testPostStoreGetForThread(t, ss) }) t.Run("HasAutoResponsePostByUserSince", func(t *testing.T) { testHasAutoResponsePostByUserSince(t, ss) }) + t.Run("GetPostsSinceForSync", func(t *testing.T) { testGetPostsSinceForSync(t, ss, s) }) } func testPostStoreSave(t *testing.T, ss store.Store) { @@ -3013,3 +3014,118 @@ func testHasAutoResponsePostByUserSince(t *testing.T, ss store.Store) { assert.False(t, exists) }) } + +func testGetPostsSinceForSync(t *testing.T, ss store.Store, s SqlStore) { + // create some posts. + channelID := model.NewId() + remoteID := model.NewString(model.NewId()) + first := model.GetMillis() + + data := []*model.Post{ + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 0"}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 1"}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 2"}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 3", RemoteId: remoteID}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 4", RemoteId: remoteID}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 5", RemoteId: remoteID}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 6", RemoteId: remoteID}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 7"}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 8", DeleteAt: model.GetMillis()}, + {Id: model.NewId(), ChannelId: channelID, UserId: model.NewId(), Message: "test post 9", DeleteAt: model.GetMillis()}, + } + + for i, p := range data { + p.UpdateAt = first + (int64(i) * 300000) + if p.RemoteId == nil { + p.RemoteId = model.NewString(model.NewId()) + } + _, err := ss.Post().Save(p) + require.NoError(t, err, "couldn't save post") + } + + t.Run("Invalid channel id", func(t *testing.T) { + opt := model.GetPostsSinceForSyncOptions{ + ChannelId: model.NewId(), + } + cursor := model.GetPostsSinceForSyncCursor{} + posts, cursorOut, err := ss.Post().GetPostsSinceForSync(opt, cursor, 100) + require.NoError(t, err) + require.Empty(t, posts, "should return zero posts") + require.Equal(t, cursor, cursorOut) + }) + + t.Run("Get by channel, exclude remotes, exclude deleted", func(t *testing.T) { + opt := model.GetPostsSinceForSyncOptions{ + ChannelId: channelID, + ExcludeRemoteId: *remoteID, + } + cursor := model.GetPostsSinceForSyncCursor{} + posts, _, err := ss.Post().GetPostsSinceForSync(opt, cursor, 100) + require.NoError(t, err) + + require.ElementsMatch(t, getPostIds(data[0:3], data[7]), getPostIds(posts)) + }) + + t.Run("Include deleted", func(t *testing.T) { + opt := model.GetPostsSinceForSyncOptions{ + ChannelId: channelID, + IncludeDeleted: true, + } + cursor := model.GetPostsSinceForSyncCursor{} + posts, _, err := ss.Post().GetPostsSinceForSync(opt, cursor, 100) + require.NoError(t, err) + + require.ElementsMatch(t, getPostIds(data), getPostIds(posts)) + }) + + t.Run("Limit and cursor", func(t *testing.T) { + opt := model.GetPostsSinceForSyncOptions{ + ChannelId: channelID, + } + cursor := model.GetPostsSinceForSyncCursor{} + posts1, cursor, err := ss.Post().GetPostsSinceForSync(opt, cursor, 5) + require.NoError(t, err) + require.Len(t, posts1, 5, "should get 5 posts") + + posts2, _, err := ss.Post().GetPostsSinceForSync(opt, cursor, 5) + require.NoError(t, err) + require.Len(t, posts2, 3, "should get 3 posts") + + require.ElementsMatch(t, getPostIds(data[0:8]), getPostIds(posts1, posts2...)) + }) + + t.Run("UpdateAt collisions", func(t *testing.T) { + // this test requires all the UpdateAt timestamps to be the same. + args := map[string]interface{}{"UpdateAt": model.GetMillis()} + result, err := s.GetMaster().Exec("UPDATE Posts SET UpdateAt = :UpdateAt", args) + require.NoError(t, err) + rows, err := result.RowsAffected() + require.NoError(t, err) + require.Greater(t, rows, int64(0)) + + opt := model.GetPostsSinceForSyncOptions{ + ChannelId: channelID, + } + cursor := model.GetPostsSinceForSyncCursor{} + posts1, cursor, err := ss.Post().GetPostsSinceForSync(opt, cursor, 5) + require.NoError(t, err) + require.Len(t, posts1, 5, "should get 5 posts") + + posts2, _, err := ss.Post().GetPostsSinceForSync(opt, cursor, 5) + require.NoError(t, err) + require.Len(t, posts2, 3, "should get 3 posts") + + require.ElementsMatch(t, getPostIds(data[0:8]), getPostIds(posts1, posts2...)) + }) +} + +func getPostIds(posts []*model.Post, morePosts ...*model.Post) []string { + ids := make([]string, 0, len(posts)+len(morePosts)) + for _, p := range posts { + ids = append(ids, p.Id) + } + for _, p := range morePosts { + ids = append(ids, p.Id) + } + return ids +} diff --git a/store/storetest/shared_channel_store.go b/store/storetest/shared_channel_store.go index fe85e04ea2..04ed96d058 100644 --- a/store/storetest/shared_channel_store.go +++ b/store/storetest/shared_channel_store.go @@ -30,11 +30,13 @@ func TestSharedChannelStore(t *testing.T, ss store.Store, s SqlStore) { t.Run("GetSharedChannelRemotes", func(t *testing.T) { testGetSharedChannelRemotes(t, ss) }) t.Run("HasRemote", func(t *testing.T) { testHasRemote(t, ss) }) t.Run("GetRemoteForUser", func(t *testing.T) { testGetRemoteForUser(t, ss) }) - t.Run("UpdateSharedChannelRemoteNextSyncAt", func(t *testing.T) { testUpdateSharedChannelRemoteNextSyncAt(t, ss) }) + t.Run("UpdateSharedChannelRemoteNextSyncAt", func(t *testing.T) { testUpdateSharedChannelRemoteCursor(t, ss) }) t.Run("DeleteSharedChannelRemote", func(t *testing.T) { testDeleteSharedChannelRemote(t, ss) }) t.Run("SaveSharedChannelUser", func(t *testing.T) { testSaveSharedChannelUser(t, ss) }) + t.Run("GetSharedChannelSingleUser", func(t *testing.T) { testGetSingleSharedChannelUser(t, ss) }) t.Run("GetSharedChannelUser", func(t *testing.T) { testGetSharedChannelUser(t, ss) }) + t.Run("GetSharedChannelUsersForSync", func(t *testing.T) { testGetSharedChannelUsersForSync(t, ss) }) t.Run("UpdateSharedChannelUserLastSyncAt", func(t *testing.T) { testUpdateSharedChannelUserLastSyncAt(t, ss) }) t.Run("SaveSharedChannelAttachment", func(t *testing.T) { testSaveSharedChannelAttachment(t, ss) }) @@ -714,7 +716,7 @@ func testGetRemoteForUser(t *testing.T, ss store.Store) { }) } -func testUpdateSharedChannelRemoteNextSyncAt(t *testing.T, ss store.Store) { +func testUpdateSharedChannelRemoteCursor(t *testing.T, ss store.Store) { channel, err := createTestChannel(ss, "test_remote_update_next_sync_at") require.NoError(t, err) @@ -728,18 +730,25 @@ func testUpdateSharedChannelRemoteNextSyncAt(t *testing.T, ss store.Store) { require.NoError(t, err, "couldn't save remote", err) future := model.GetMillis() + 3600000 // 1 hour in the future + postID := model.NewId() + + cursor := model.GetPostsSinceForSyncCursor{ + LastPostUpdateAt: future, + LastPostId: postID, + } t.Run("Update NextSyncAt for remote", func(t *testing.T) { - err := ss.SharedChannel().UpdateRemoteNextSyncAt(remoteSaved.Id, future) + err := ss.SharedChannel().UpdateRemoteCursor(remoteSaved.Id, cursor) require.NoError(t, err, "update NextSyncAt should not error", err) r, err := ss.SharedChannel().GetRemote(remoteSaved.Id) require.NoError(t, err) - require.Equal(t, future, r.NextSyncAt) + require.Equal(t, future, r.LastPostUpdateAt) + require.Equal(t, postID, r.LastPostId) }) t.Run("Update NextSyncAt for non-existent shared channel remote", func(t *testing.T) { - err := ss.SharedChannel().UpdateRemoteNextSyncAt(model.NewId(), future) + err := ss.SharedChannel().UpdateRemoteCursor(model.NewId(), cursor) require.Error(t, err, "update non-existent remote should error", err) }) } @@ -862,7 +871,7 @@ func testSaveSharedChannelUser(t *testing.T, ss store.Store) { }) } -func testGetSharedChannelUser(t *testing.T, ss store.Store) { +func testGetSingleSharedChannelUser(t *testing.T, ss store.Store) { scUser := &model.SharedChannelUser{ UserId: model.NewId(), RemoteId: model.NewId(), @@ -873,7 +882,7 @@ func testGetSharedChannelUser(t *testing.T, ss store.Store) { require.NoError(t, err, "could not save user", err) t.Run("Get existing shared channel user", func(t *testing.T) { - r, err := ss.SharedChannel().GetUser(userSaved.UserId, userSaved.ChannelId, userSaved.RemoteId) + r, err := ss.SharedChannel().GetSingleUser(userSaved.UserId, userSaved.ChannelId, userSaved.RemoteId) require.NoError(t, err, "couldn't get shared channel user", err) require.Equal(t, userSaved.Id, r.Id) @@ -883,35 +892,174 @@ func testGetSharedChannelUser(t *testing.T, ss store.Store) { }) t.Run("Get non-existent shared channel user", func(t *testing.T) { - u, err := ss.SharedChannel().GetUser(model.NewId(), model.NewId(), model.NewId()) + u, err := ss.SharedChannel().GetSingleUser(model.NewId(), model.NewId(), model.NewId()) require.Error(t, err) require.Nil(t, u) }) } -func testUpdateSharedChannelUserLastSyncAt(t *testing.T, ss store.Store) { - scUser := &model.SharedChannelUser{ - UserId: model.NewId(), - RemoteId: model.NewId(), - ChannelId: model.NewId(), +func testGetSharedChannelUser(t *testing.T, ss store.Store) { + userId := model.NewId() + for i := 0; i < 10; i++ { + scUser := &model.SharedChannelUser{ + UserId: userId, + RemoteId: model.NewId(), + ChannelId: model.NewId(), + } + _, err := ss.SharedChannel().SaveUser(scUser) + require.NoError(t, err, "could not save user", err) } - userSaved, err := ss.SharedChannel().SaveUser(scUser) + t.Run("Get existing shared channel user", func(t *testing.T) { + scus, err := ss.SharedChannel().GetUsersForUser(userId) + require.NoError(t, err, "couldn't get shared channel user", err) + + require.Len(t, scus, 10, "should be 10 shared channel user records") + require.Equal(t, userId, scus[0].UserId) + }) + + t.Run("Get non-existent shared channel user", func(t *testing.T) { + scus, err := ss.SharedChannel().GetUsersForUser(model.NewId()) + require.NoError(t, err, "should not error when not found") + require.Empty(t, scus, "should be empty") + }) +} + +func testGetSharedChannelUsersForSync(t *testing.T, ss store.Store) { + channelID := model.NewId() + remoteID := model.NewId() + earlier := model.GetMillis() - 300000 + later := model.GetMillis() + 300000 + + var users []*model.User + for i := 0; i < 10; i++ { // need real users + u := &model.User{ + Username: model.NewId(), + Email: model.NewId() + "@example.com", + LastPictureUpdate: model.GetMillis(), + } + u, err := ss.User().Save(u) + require.NoError(t, err) + users = append(users, u) + } + + data := []model.SharedChannelUser{ + {UserId: users[0].Id, ChannelId: model.NewId(), RemoteId: model.NewId(), LastSyncAt: later}, + {UserId: users[1].Id, ChannelId: model.NewId(), RemoteId: model.NewId(), LastSyncAt: earlier}, + {UserId: users[1].Id, ChannelId: model.NewId(), RemoteId: model.NewId(), LastSyncAt: earlier}, + {UserId: users[1].Id, ChannelId: channelID, RemoteId: remoteID, LastSyncAt: later}, + {UserId: users[2].Id, ChannelId: channelID, RemoteId: model.NewId(), LastSyncAt: later}, + {UserId: users[3].Id, ChannelId: channelID, RemoteId: model.NewId(), LastSyncAt: earlier}, + {UserId: users[4].Id, ChannelId: channelID, RemoteId: model.NewId(), LastSyncAt: later}, + {UserId: users[5].Id, ChannelId: channelID, RemoteId: remoteID, LastSyncAt: earlier}, + {UserId: users[6].Id, ChannelId: channelID, RemoteId: remoteID, LastSyncAt: later}, + } + + for i, u := range data { + scu := &model.SharedChannelUser{ + UserId: u.UserId, + ChannelId: u.ChannelId, + RemoteId: u.RemoteId, + LastSyncAt: u.LastSyncAt, + } + _, err := ss.SharedChannel().SaveUser(scu) + require.NoError(t, err, "could not save user #", i, err) + } + + t.Run("Filter by channelId", func(t *testing.T) { + filter := model.GetUsersForSyncFilter{ + CheckProfileImage: false, + ChannelID: channelID, + } + usersFound, err := ss.SharedChannel().GetUsersForSync(filter) + require.NoError(t, err, "shouldn't error getting users", err) + require.Len(t, usersFound, 2) + for _, user := range usersFound { + require.Contains(t, []string{users[3].Id, users[5].Id}, user.Id) + } + }) + + t.Run("Filter by channelId for profile image", func(t *testing.T) { + filter := model.GetUsersForSyncFilter{ + CheckProfileImage: true, + ChannelID: channelID, + } + usersFound, err := ss.SharedChannel().GetUsersForSync(filter) + require.NoError(t, err, "shouldn't error getting users", err) + require.Len(t, usersFound, 2) + for _, user := range usersFound { + require.Contains(t, []string{users[3].Id, users[5].Id}, user.Id) + } + }) + + t.Run("Filter by channelId with Limit", func(t *testing.T) { + filter := model.GetUsersForSyncFilter{ + CheckProfileImage: true, + ChannelID: channelID, + Limit: 1, + } + usersFound, err := ss.SharedChannel().GetUsersForSync(filter) + require.NoError(t, err, "shouldn't error getting users", err) + require.Len(t, usersFound, 1) + }) +} + +func testUpdateSharedChannelUserLastSyncAt(t *testing.T, ss store.Store) { + u1 := &model.User{ + Username: model.NewId(), + Email: model.NewId() + "@example.com", + LastPictureUpdate: model.GetMillis() - 300000, // 5 mins + } + u1, err := ss.User().Save(u1) + require.NoError(t, err) + + u2 := &model.User{ + Username: model.NewId(), + Email: model.NewId() + "@example.com", + LastPictureUpdate: model.GetMillis() + 300000, + } + u2, err = ss.User().Save(u2) + require.NoError(t, err) + + channelID := model.NewId() + remoteID := model.NewId() + + scUser1 := &model.SharedChannelUser{ + UserId: u1.Id, + RemoteId: remoteID, + ChannelId: channelID, + } + _, err = ss.SharedChannel().SaveUser(scUser1) require.NoError(t, err, "couldn't save user", err) - future := model.GetMillis() + 3600000 // 1 hour in the future + scUser2 := &model.SharedChannelUser{ + UserId: u2.Id, + RemoteId: remoteID, + ChannelId: channelID, + } + _, err = ss.SharedChannel().SaveUser(scUser2) + require.NoError(t, err, "couldn't save user", err) - t.Run("Update LastSyncAt for user", func(t *testing.T) { - err := ss.SharedChannel().UpdateUserLastSyncAt(userSaved.Id, future) + t.Run("Update LastSyncAt for user via UpdateAt", func(t *testing.T) { + err := ss.SharedChannel().UpdateUserLastSyncAt(u1.Id, channelID, remoteID) require.NoError(t, err, "updateLastSyncAt should not error", err) - u, err := ss.SharedChannel().GetUser(userSaved.UserId, userSaved.ChannelId, userSaved.RemoteId) + scu, err := ss.SharedChannel().GetSingleUser(u1.Id, channelID, remoteID) require.NoError(t, err) - require.Equal(t, future, u.LastSyncAt) + require.Equal(t, u1.UpdateAt, scu.LastSyncAt) + }) + + t.Run("Update LastSyncAt for user via LastPictureUpdate", func(t *testing.T) { + err := ss.SharedChannel().UpdateUserLastSyncAt(u2.Id, channelID, remoteID) + require.NoError(t, err, "updateLastSyncAt should not error", err) + + scu, err := ss.SharedChannel().GetSingleUser(u2.Id, channelID, remoteID) + require.NoError(t, err) + require.Equal(t, u2.LastPictureUpdate, scu.LastSyncAt) }) t.Run("Update LastSyncAt for non-existent shared channel user", func(t *testing.T) { - err := ss.SharedChannel().UpdateUserLastSyncAt(model.NewId(), future) + err := ss.SharedChannel().UpdateUserLastSyncAt(model.NewId(), channelID, remoteID) require.Error(t, err, "update non-existent user should error", err) }) } diff --git a/store/timerlayer/timerlayer.go b/store/timerlayer/timerlayer.go index fabcc22e82..5332a56522 100644 --- a/store/timerlayer/timerlayer.go +++ b/store/timerlayer/timerlayer.go @@ -4890,10 +4890,10 @@ func (s *TimerLayerPostStore) GetPostsSince(options model.GetPostsSinceOptions, return result, err } -func (s *TimerLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, allowFromCache bool) ([]*model.Post, error) { +func (s *TimerLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceForSyncOptions, cursor model.GetPostsSinceForSyncCursor, limit int) ([]*model.Post, model.GetPostsSinceForSyncCursor, error) { start := timemodule.Now() - result, err := s.PostStore.GetPostsSinceForSync(options, allowFromCache) + result, resultVar1, err := s.PostStore.GetPostsSinceForSync(options, cursor, limit) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -4903,7 +4903,7 @@ func (s *TimerLayerPostStore) GetPostsSinceForSync(options model.GetPostsSinceFo } s.Root.Metrics.ObserveStoreMethodDuration("PostStore.GetPostsSinceForSync", success, elapsed) } - return result, err + return result, resultVar1, err } func (s *TimerLayerPostStore) GetRepliesForExport(parentID string) ([]*model.ReplyForExport, error) { @@ -6568,10 +6568,10 @@ func (s *TimerLayerSharedChannelStore) GetRemotesStatus(channelId string) ([]*mo return result, err } -func (s *TimerLayerSharedChannelStore) GetUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { +func (s *TimerLayerSharedChannelStore) GetSingleUser(userID string, channelID string, remoteID string) (*model.SharedChannelUser, error) { start := timemodule.Now() - result, err := s.SharedChannelStore.GetUser(userID, channelID, remoteID) + result, err := s.SharedChannelStore.GetSingleUser(userID, channelID, remoteID) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -6579,7 +6579,39 @@ func (s *TimerLayerSharedChannelStore) GetUser(userID string, channelID string, if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUser", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetSingleUser", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetUsersForSync(filter model.GetUsersForSyncFilter) ([]*model.User, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetUsersForSync(filter) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUsersForSync", success, elapsed) + } + return result, err +} + +func (s *TimerLayerSharedChannelStore) GetUsersForUser(userID string) ([]*model.SharedChannelUser, error) { + start := timemodule.Now() + + result, err := s.SharedChannelStore.GetUsersForUser(userID) + + elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.GetUsersForUser", success, elapsed) } return result, err } @@ -6728,10 +6760,10 @@ func (s *TimerLayerSharedChannelStore) UpdateRemote(remote *model.SharedChannelR return result, err } -func (s *TimerLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTime int64) error { +func (s *TimerLayerSharedChannelStore) UpdateRemoteCursor(id string, cursor model.GetPostsSinceForSyncCursor) error { start := timemodule.Now() - err := s.SharedChannelStore.UpdateRemoteNextSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateRemoteCursor(id, cursor) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -6739,15 +6771,15 @@ func (s *TimerLayerSharedChannelStore) UpdateRemoteNextSyncAt(id string, syncTim if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateRemoteNextSyncAt", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("SharedChannelStore.UpdateRemoteCursor", success, elapsed) } return err } -func (s *TimerLayerSharedChannelStore) UpdateUserLastSyncAt(id string, syncTime int64) error { +func (s *TimerLayerSharedChannelStore) UpdateUserLastSyncAt(userID string, channelID string, remoteID string) error { start := timemodule.Now() - err := s.SharedChannelStore.UpdateUserLastSyncAt(id, syncTime) + err := s.SharedChannelStore.UpdateUserLastSyncAt(userID, channelID, remoteID) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil {