diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index b3dcbed304..bc5ee9e2d7 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -3276,6 +3276,11 @@ func (a *App) MoveChannel(c request.CTX, team *model.Team, channel *model.Channe } } + // Update the threads within this channel to the new team + if err := a.Srv().Store().Thread().UpdateTeamIdForChannelThreads(channel.Id, team.Id); err != nil { + c.Logger().Warn("error while updating threads after channel move", mlog.Err(err)) + } + if err := a.RemoveUsersFromChannelNotMemberOfTeam(c, user, channel, team); err != nil { c.Logger().Warn("error while removing non-team member users", mlog.Err(err)) } diff --git a/server/channels/app/channel_test.go b/server/channels/app/channel_test.go index 02bd703a51..340722ec2a 100644 --- a/server/channels/app/channel_test.go +++ b/server/channels/app/channel_test.go @@ -247,6 +247,59 @@ func TestMoveChannel(t *testing.T) { require.Equal(t, model.SidebarCategoryChannels, categories.Categories[1].Type) assert.Contains(t, categories.Categories[1].Channels, channel.Id) }) + + t.Run("should update threads when moving channels between teams", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + sourceTeam := th.CreateTeam() + targetTeam := th.CreateTeam() + channel := th.CreateChannel(th.Context, sourceTeam) + + th.LinkUserToTeam(th.BasicUser, sourceTeam) + th.LinkUserToTeam(th.BasicUser, targetTeam) + th.AddUserToChannel(th.BasicUser, channel) + + // Create a thread in the channel + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: channel.Id, + Message: "test", + } + post, appErr := th.App.CreatePost(th.Context, post, channel, model.CreatePostFlags{}) + require.Nil(t, appErr) + + // Post a reply to the thread + reply := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: channel.Id, + RootId: post.Id, + Message: "reply", + } + _, appErr = th.App.CreatePost(th.Context, reply, channel, model.CreatePostFlags{}) + require.Nil(t, appErr) + + // Check that the thread count before move + threads, appErr := th.App.GetThreadsForUser(th.BasicUser.Id, targetTeam.Id, model.GetUserThreadsOpts{}) + require.Nil(t, appErr) + + require.Zero(t, threads.Total) + + // Move the channel to the target team + appErr = th.App.MoveChannel(th.Context, targetTeam, channel, th.BasicUser) + require.Nil(t, appErr) + + // Check that the thread was moved + threads, appErr = th.App.GetThreadsForUser(th.BasicUser.Id, targetTeam.Id, model.GetUserThreadsOpts{}) + require.Nil(t, appErr) + + require.Equal(t, int64(1), threads.Total) + // Check that the thread count after move + threads, appErr = th.App.GetThreadsForUser(th.BasicUser.Id, sourceTeam.Id, model.GetUserThreadsOpts{}) + require.Nil(t, appErr) + + require.Zero(t, threads.Total) + }) } func TestRemoveUsersFromChannelNotMemberOfTeam(t *testing.T) { diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 0e29f23f1b..1a6f19d3b2 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -13545,6 +13545,27 @@ func (s *RetryLayerThreadStore) UpdateMembership(membership *model.ThreadMembers } +func (s *RetryLayerThreadStore) UpdateTeamIdForChannelThreads(channelId string, teamId string) error { + + tries := 0 + for { + err := s.ThreadStore.UpdateTeamIdForChannelThreads(channelId, teamId) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) { s.TokenStore.Cleanup(expiryTime) diff --git a/server/channels/store/sqlstore/thread_store.go b/server/channels/store/sqlstore/thread_store.go index 97d293c92c..d46a2e73fc 100644 --- a/server/channels/store/sqlstore/thread_store.go +++ b/server/channels/store/sqlstore/thread_store.go @@ -1139,3 +1139,25 @@ func (s *SqlThreadStore) updateThreadParticipantsForUserTx(trx *sqlxTxWrapper, p return nil } + +// UpdateTeamIdForChannelThreads updates the team id for all threads in a channel. +// Specifically used when a channel is moved to a different team. +// If a user is not member of the new team, the threads will be deleted by the +// channel move process. +func (s *SqlThreadStore) UpdateTeamIdForChannelThreads(channelId, teamId string) error { + query := s.getQueryBuilder(). + Update("Threads"). + Set("ThreadTeamId", teamId). + Where( + sq.And{ + sq.Eq{"ChannelId": channelId}, + sq.Expr("EXISTS(SELECT 1 FROM Teams WHERE Id = ?)", teamId), + }) + + _, err := s.GetMaster().ExecBuilder(query) + if err != nil { + return errors.Wrapf(err, "failed to update threads team id for channel id=%s", channelId) + } + + return nil +} diff --git a/server/channels/store/store.go b/server/channels/store/store.go index eec2b9c4f1..7e77e80698 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -360,6 +360,7 @@ type ThreadStore interface { SaveMultipleMemberships(memberships []*model.ThreadMembership) ([]*model.ThreadMembership, error) MaintainMultipleFromImport(memberships []*model.ThreadMembership) ([]*model.ThreadMembership, error) + UpdateTeamIdForChannelThreads(channelId, teamId string) error } type PostStore interface { diff --git a/server/channels/store/storetest/mocks/ThreadStore.go b/server/channels/store/storetest/mocks/ThreadStore.go index 8dde1db6c6..85e23756da 100644 --- a/server/channels/store/storetest/mocks/ThreadStore.go +++ b/server/channels/store/storetest/mocks/ThreadStore.go @@ -721,6 +721,24 @@ func (_m *ThreadStore) UpdateMembership(membership *model.ThreadMembership) (*mo return r0, r1 } +// UpdateTeamIdForChannelThreads provides a mock function with given fields: channelId, teamId +func (_m *ThreadStore) UpdateTeamIdForChannelThreads(channelId string, teamId string) error { + ret := _m.Called(channelId, teamId) + + if len(ret) == 0 { + panic("no return value specified for UpdateTeamIdForChannelThreads") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(channelId, teamId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewThreadStore creates a new instance of ThreadStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewThreadStore(t interface { diff --git a/server/channels/store/storetest/thread_store.go b/server/channels/store/storetest/thread_store.go index 2a7b55d4d7..acc920532b 100644 --- a/server/channels/store/storetest/thread_store.go +++ b/server/channels/store/storetest/thread_store.go @@ -32,6 +32,7 @@ func TestThreadStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) t.Run("DeleteMembershipsForChannel", func(t *testing.T) { testDeleteMembershipsForChannel(t, rctx, ss) }) t.Run("SaveMultipleMemberships", func(t *testing.T) { testSaveMultipleMemberships(t, ss) }) t.Run("MaintainMultipleFromImport", func(t *testing.T) { testMaintainMultipleFromImport(t, rctx, ss) }) + t.Run("UpdateTeamIdForChannelThreads", func(t *testing.T) { testUpdateTeamIdForChannelThreads(t, rctx, ss) }) } func testThreadStorePopulation(t *testing.T, rctx request.CTX, ss store.Store) { @@ -2016,3 +2017,113 @@ func testMaintainMultipleFromImport(t *testing.T, rctx request.CTX, ss store.Sto require.NoError(t, err) }) } + +func testUpdateTeamIdForChannelThreads(t *testing.T, rctx request.CTX, ss store.Store) { + createThreadMembership := func(userID, postID string, following bool) (*model.ThreadMembership, func()) { + t.Helper() + opts := store.ThreadMembershipOpts{ + Following: following, + IncrementMentions: false, + UpdateFollowing: true, + UpdateViewedTimestamp: false, + UpdateParticipants: true, + } + mem, err := ss.Thread().MaintainMembership(userID, postID, opts) + require.NoError(t, err) + + return mem, func() { + err := ss.Thread().DeleteMembershipForUser(userID, postID) + require.NoError(t, err) + } + } + + postingUserID := model.NewId() + + team1, err := ss.Team().Save(&model.Team{ + DisplayName: "DisplayName", + Name: "team" + model.NewId(), + Email: MakeEmail(), + Type: model.TeamOpen, + }) + require.NoError(t, err) + + team2, err := ss.Team().Save(&model.Team{ + DisplayName: "DisplayNameTwo", + Name: "team" + model.NewId(), + Email: MakeEmail(), + Type: model.TeamOpen, + }) + require.NoError(t, err) + + channel1, err := ss.Channel().Save(rctx, &model.Channel{ + TeamId: team1.Id, + DisplayName: "DisplayName", + Name: "channel1" + model.NewId(), + Type: model.ChannelTypeOpen, + }, -1) + require.NoError(t, err) + + rootPost1, err := ss.Post().Save(rctx, &model.Post{ + ChannelId: channel1.Id, + UserId: postingUserID, + Message: model.NewRandomString(10), + }) + require.NoError(t, err) + + _, err = ss.Post().Save(rctx, &model.Post{ + ChannelId: channel1.Id, + UserId: postingUserID, + Message: model.NewRandomString(10), + RootId: rootPost1.Id, + }) + require.NoError(t, err) + + t.Run("Should move threads to the new team", func(t *testing.T) { + userA, err := ss.User().Save(request.TestContext(t), &model.User{ + Username: model.NewId(), + Email: MakeEmail(), + Password: model.NewId(), + }) + require.NoError(t, err) + + _, clean := createThreadMembership(userA.Id, rootPost1.Id, true) + defer clean() + + err = ss.Thread().UpdateTeamIdForChannelThreads(channel1.Id, team2.Id) + require.NoError(t, err) + + defer func() { + err = ss.Thread().UpdateTeamIdForChannelThreads(channel1.Id, team1.Id) + require.NoError(t, err) + }() + + threads, err := ss.Thread().GetThreadsForUser(userA.Id, team2.Id, model.GetUserThreadsOpts{}) + require.NoError(t, err) + require.Len(t, threads, 1) + }) + + t.Run("Should not move threads to a non existent team", func(t *testing.T) { + userA, err := ss.User().Save(request.TestContext(t), &model.User{ + Username: model.NewId(), + Email: MakeEmail(), + Password: model.NewId(), + }) + require.NoError(t, err) + + newTeamID := model.NewId() + + _, clean := createThreadMembership(userA.Id, rootPost1.Id, true) + t.Cleanup(clean) + + err = ss.Thread().UpdateTeamIdForChannelThreads(channel1.Id, newTeamID) + require.NoError(t, err) + + threads, err := ss.Thread().GetThreadsForUser(userA.Id, newTeamID, model.GetUserThreadsOpts{}) + require.NoError(t, err) + require.Len(t, threads, 0) + + threads, err = ss.Thread().GetThreadsForUser(userA.Id, team1.Id, model.GetUserThreadsOpts{}) + require.NoError(t, err) + require.Len(t, threads, 1) + }) +} diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 4c6b2c262f..51d8d95975 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -10639,6 +10639,22 @@ func (s *TimerLayerThreadStore) UpdateMembership(membership *model.ThreadMembers return result, err } +func (s *TimerLayerThreadStore) UpdateTeamIdForChannelThreads(channelId string, teamId string) error { + start := time.Now() + + err := s.ThreadStore.UpdateTeamIdForChannelThreads(channelId, teamId) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ThreadStore.UpdateTeamIdForChannelThreads", success, elapsed) + } + return err +} + func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) { start := time.Now()