mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-45009: Delete ThreadMemberships from "left" channels (#22559)
* MM-50550: Filter out threads from "left" channels v2 Currently leaving a channel doesn't affect the thread memberships of that user/channel combination. This PR aims to filter out all threads from those channels for the user. Adds a DeleteAt column in the ThreadMemberships table, and filter out all thread memberships that are "deleted". Each time a user leaves a channel all thread memberships are going to be marked as deleted, and when a user joins a channel again all those existing thread memberships will be re-instantiated. Adds a migration to mark all existing thread memberships as deleted depending on whether there exists a channel membership for that channel/user. * Added migration files into list * Fixes tests * Fixes case where DeleteAt would be null * Guard thread API endpoints with appropriate perms * Deletes ThreadMembership rows upon leaving channel * Minor style changes * Use NoTranslation error * Refactors tests * Adds API tests to assert permissions on Team * Adds tests, and fixes migrations * Fixes test description * Fix test * Removes check on DM/GMs * Change the MySQL query in the migration --------- Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
parent
c34a50a6c7
commit
a24111f9bd
@ -3106,6 +3106,10 @@ func getThreadForUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.SetPermissionError(model.PermissionEditOtherUsers)
|
||||
return
|
||||
}
|
||||
if !c.App.SessionHasPermissionToChannelByPost(*c.AppContext.Session(), c.Params.ThreadId, model.PermissionReadChannel) {
|
||||
c.SetPermissionError(model.PermissionReadChannel)
|
||||
return
|
||||
}
|
||||
extendedStr := r.URL.Query().Get("extended")
|
||||
extended, _ := strconv.ParseBool(extendedStr)
|
||||
|
||||
@ -3136,6 +3140,10 @@ func getThreadsForUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.SetPermissionError(model.PermissionEditOtherUsers)
|
||||
return
|
||||
}
|
||||
if !c.App.SessionHasPermissionToTeam(*c.AppContext.Session(), c.Params.TeamId, model.PermissionViewTeam) {
|
||||
c.SetPermissionError(model.PermissionViewTeam)
|
||||
return
|
||||
}
|
||||
|
||||
options := model.GetUserThreadsOpts{
|
||||
Since: 0,
|
||||
@ -3213,6 +3221,10 @@ func updateReadStateThreadByUser(c *Context, w http.ResponseWriter, r *http.Requ
|
||||
c.SetPermissionError(model.PermissionEditOtherUsers)
|
||||
return
|
||||
}
|
||||
if !c.App.SessionHasPermissionToChannelByPost(*c.AppContext.Session(), c.Params.ThreadId, model.PermissionReadChannel) {
|
||||
c.SetPermissionError(model.PermissionReadChannel)
|
||||
return
|
||||
}
|
||||
|
||||
thread, err := c.App.UpdateThreadReadForUser(c.AppContext, c.AppContext.Session().Id, c.Params.UserId, c.Params.TeamId, c.Params.ThreadId, c.Params.Timestamp)
|
||||
if err != nil {
|
||||
@ -3279,6 +3291,10 @@ func unfollowThreadByUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.SetPermissionError(model.PermissionEditOtherUsers)
|
||||
return
|
||||
}
|
||||
if !c.App.SessionHasPermissionToChannelByPost(*c.AppContext.Session(), c.Params.ThreadId, model.PermissionReadChannel) {
|
||||
c.SetPermissionError(model.PermissionReadChannel)
|
||||
return
|
||||
}
|
||||
|
||||
err := c.App.UpdateThreadFollowForUser(c.Params.UserId, c.Params.TeamId, c.Params.ThreadId, false)
|
||||
if err != nil {
|
||||
@ -3338,6 +3354,10 @@ func updateReadStateAllThreadsByUser(c *Context, w http.ResponseWriter, r *http.
|
||||
c.SetPermissionError(model.PermissionEditOtherUsers)
|
||||
return
|
||||
}
|
||||
if !c.App.SessionHasPermissionToTeam(*c.AppContext.Session(), c.Params.TeamId, model.PermissionViewTeam) {
|
||||
c.SetPermissionError(model.PermissionViewTeam)
|
||||
return
|
||||
}
|
||||
|
||||
err := c.App.UpdateThreadsReadForUser(c.Params.UserId, c.Params.TeamId)
|
||||
if err != nil {
|
||||
|
@ -6360,6 +6360,15 @@ func TestGetThreadsForUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uss.TotalUnreadThreads, int64(2))
|
||||
})
|
||||
|
||||
t.Run("should error when not a team member", func(t *testing.T) {
|
||||
th.UnlinkUserFromTeam(th.BasicUser, th.BasicTeam)
|
||||
defer th.LinkUserToTeam(th.BasicUser, th.BasicTeam)
|
||||
|
||||
_, resp, err := th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{})
|
||||
require.Error(t, err)
|
||||
CheckForbiddenStatus(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestThreadSocketEvents(t *testing.T) {
|
||||
@ -6855,52 +6864,64 @@ func TestSingleThreadGet(t *testing.T) {
|
||||
})
|
||||
|
||||
client := th.Client
|
||||
defer th.App.Srv().Store().Post().PermanentDeleteByUser(th.BasicUser.Id)
|
||||
defer th.App.Srv().Store().Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
|
||||
|
||||
// create a post by regular user
|
||||
rpost, _ := postAndCheck(t, client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testMsg"})
|
||||
// reply with another
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply", RootId: rpost.Id})
|
||||
t.Run("get single thread", func(t *testing.T) {
|
||||
defer th.App.Srv().Store().Post().PermanentDeleteByUser(th.BasicUser.Id)
|
||||
defer th.App.Srv().Store().Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
|
||||
|
||||
// create another thread to check that we are not returning it by mistake
|
||||
rpost2, _ := postAndCheck(t, client, &model.Post{
|
||||
ChannelId: th.BasicChannel2.Id,
|
||||
Message: "testMsg2",
|
||||
Metadata: &model.PostMetadata{
|
||||
Priority: &model.PostPriority{
|
||||
Priority: model.NewString(model.PostPriorityUrgent),
|
||||
// create a post by regular user
|
||||
rpost, _ := postAndCheck(t, client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testMsg"})
|
||||
// reply with another
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply", RootId: rpost.Id})
|
||||
|
||||
// create another thread to check that we are not returning it by mistake
|
||||
rpost2, _ := postAndCheck(t, client, &model.Post{
|
||||
ChannelId: th.BasicChannel2.Id,
|
||||
Message: "testMsg2",
|
||||
Metadata: &model.PostMetadata{
|
||||
Priority: &model.PostPriority{
|
||||
Priority: model.NewString(model.PostPriorityUrgent),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel2.Id, Message: "testReply", RootId: rpost2.Id})
|
||||
})
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel2.Id, Message: "testReply", RootId: rpost2.Id})
|
||||
|
||||
// regular user should have two threads with 3 replies total
|
||||
threads, _ := checkThreadListReplies(t, th, th.Client, th.BasicUser.Id, 2, 2, nil)
|
||||
// regular user should have two threads with 3 replies total
|
||||
threads, _ := checkThreadListReplies(t, th, th.Client, th.BasicUser.Id, 2, 2, nil)
|
||||
|
||||
tr, _, err := th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tr)
|
||||
require.Equal(t, threads.Threads[0].PostId, tr.PostId)
|
||||
require.Empty(t, tr.Participants[0].Username)
|
||||
tr, _, err := th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tr)
|
||||
require.Equal(t, threads.Threads[0].PostId, tr.PostId)
|
||||
require.Empty(t, tr.Participants[0].Username)
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.PostPriority = false
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.PostPriority = false
|
||||
})
|
||||
|
||||
tr, _, err = th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, true)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tr.Participants[0].Username)
|
||||
require.Equal(t, false, tr.IsUrgent)
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.PostPriority = true
|
||||
cfg.FeatureFlags.PostPriority = true
|
||||
})
|
||||
|
||||
tr, _, err = th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, tr.IsUrgent)
|
||||
})
|
||||
|
||||
tr, _, err = th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, true)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tr.Participants[0].Username)
|
||||
require.Equal(t, false, tr.IsUrgent)
|
||||
t.Run("should error when not a team member", func(t *testing.T) {
|
||||
th.UnlinkUserFromTeam(th.BasicUser, th.BasicTeam)
|
||||
defer th.LinkUserToTeam(th.BasicUser, th.BasicTeam)
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.PostPriority = true
|
||||
cfg.FeatureFlags.PostPriority = true
|
||||
_, resp, err := th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, model.NewId(), false)
|
||||
require.Error(t, err)
|
||||
CheckForbiddenStatus(t, resp)
|
||||
})
|
||||
|
||||
tr, _, err = th.Client.GetUserThread(th.BasicUser.Id, th.BasicTeam.Id, threads.Threads[0].PostId, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, tr.IsUrgent)
|
||||
}
|
||||
|
||||
func TestMaintainUnreadMentionsInThread(t *testing.T) {
|
||||
@ -7072,6 +7093,23 @@ func TestReadThreads(t *testing.T) {
|
||||
|
||||
checkThreadListReplies(t, th, th.Client, th.BasicUser.Id, 1, 1, nil)
|
||||
})
|
||||
|
||||
t.Run("should error when not a team member", func(t *testing.T) {
|
||||
th.UnlinkUserFromTeam(th.BasicUser, th.BasicTeam)
|
||||
defer th.LinkUserToTeam(th.BasicUser, th.BasicTeam)
|
||||
|
||||
_, resp, err := th.Client.UpdateThreadReadForUser(th.BasicUser.Id, th.BasicTeam.Id, model.NewId(), model.GetMillis())
|
||||
require.Error(t, err)
|
||||
CheckForbiddenStatus(t, resp)
|
||||
|
||||
_, resp, err = th.Client.SetThreadUnreadByPostId(th.BasicUser.Id, th.BasicTeam.Id, model.NewId(), model.NewId())
|
||||
require.Error(t, err)
|
||||
CheckForbiddenStatus(t, resp)
|
||||
|
||||
resp, err = th.Client.UpdateThreadsReadForUser(th.BasicUser.Id, th.BasicTeam.Id)
|
||||
require.Error(t, err)
|
||||
CheckForbiddenStatus(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMarkThreadUnreadMentionCount(t *testing.T) {
|
||||
|
@ -2518,6 +2518,9 @@ func (a *App) removeUserFromChannel(c request.CTX, userIDToRemove string, remove
|
||||
if err := a.Srv().Store().ChannelMemberHistory().LogLeaveEvent(userIDToRemove, channel.Id, model.GetMillis()); err != nil {
|
||||
return model.NewAppError("removeUserFromChannel", "app.channel_member_history.log_leave_event.internal_error", nil, "", http.StatusInternalServerError).Wrap(err)
|
||||
}
|
||||
if err := a.Srv().Store().Thread().DeleteMembershipsForChannel(userIDToRemove, channel.Id); err != nil {
|
||||
return model.NewAppError("removeUserFromChannel", model.NoTranslation, nil, "failed to delete threadmemberships upon leaving channel", http.StatusInternalServerError).Wrap(err)
|
||||
}
|
||||
|
||||
if isGuest {
|
||||
currentMembers, err := a.GetChannelMembersForUser(c, channel.TeamId, userIDToRemove)
|
||||
|
@ -609,6 +609,85 @@ func TestLeaveDefaultChannel(t *testing.T) {
|
||||
_, err = th.App.GetChannelMember(th.Context, townSquare.Id, guest.Id)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("Trying to leave the default channel should not delete thread memberships", func(t *testing.T) {
|
||||
post := &model.Post{
|
||||
ChannelId: townSquare.Id,
|
||||
Message: "root post",
|
||||
UserId: th.BasicUser.Id,
|
||||
}
|
||||
rpost, err := th.App.CreatePost(th.Context, post, th.BasicChannel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
reply := &model.Post{
|
||||
ChannelId: townSquare.Id,
|
||||
Message: "reply post",
|
||||
UserId: th.BasicUser.Id,
|
||||
RootId: rpost.Id,
|
||||
}
|
||||
_, err = th.App.CreatePost(th.Context, reply, th.BasicChannel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
threads, err := th.App.GetThreadsForUser(th.BasicUser.Id, townSquare.TeamId, model.GetUserThreadsOpts{})
|
||||
require.Nil(t, err)
|
||||
require.Len(t, threads.Threads, 1)
|
||||
|
||||
err = th.App.LeaveChannel(th.Context, townSquare.Id, th.BasicUser.Id)
|
||||
assert.NotNil(t, err, "It should fail to remove a regular user from the default channel")
|
||||
assert.Equal(t, err.Id, "api.channel.remove.default.app_error")
|
||||
|
||||
threads, err = th.App.GetThreadsForUser(th.BasicUser.Id, townSquare.TeamId, model.GetUserThreadsOpts{})
|
||||
require.Nil(t, err)
|
||||
require.Len(t, threads.Threads, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLeaveChannel(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
createThread := func(channel *model.Channel) (rpost *model.Post) {
|
||||
t.Helper()
|
||||
post := &model.Post{
|
||||
ChannelId: channel.Id,
|
||||
Message: "root post",
|
||||
UserId: th.BasicUser.Id,
|
||||
}
|
||||
|
||||
rpost, err := th.App.CreatePost(th.Context, post, th.BasicChannel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
reply := &model.Post{
|
||||
ChannelId: channel.Id,
|
||||
Message: "reply post",
|
||||
UserId: th.BasicUser.Id,
|
||||
RootId: rpost.Id,
|
||||
}
|
||||
_, err = th.App.CreatePost(th.Context, reply, th.BasicChannel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
return rpost
|
||||
}
|
||||
|
||||
t.Run("thread memberships are deleted", func(t *testing.T) {
|
||||
createThread(th.BasicChannel)
|
||||
channel2 := th.createChannel(th.Context, th.BasicTeam, model.ChannelTypeOpen)
|
||||
createThread(channel2)
|
||||
|
||||
threads, err := th.App.GetThreadsForUser(th.BasicUser.Id, th.BasicChannel.TeamId, model.GetUserThreadsOpts{})
|
||||
require.Nil(t, err)
|
||||
require.Len(t, threads.Threads, 2)
|
||||
|
||||
err = th.App.LeaveChannel(th.Context, th.BasicChannel.Id, th.BasicUser.Id)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = th.App.GetChannelMember(th.Context, th.BasicChannel.Id, th.BasicUser.Id)
|
||||
require.NotNil(t, err, "It should remove channel membership")
|
||||
|
||||
threads, err = th.App.GetThreadsForUser(th.BasicUser.Id, th.BasicChannel.TeamId, model.GetUserThreadsOpts{})
|
||||
require.Nil(t, err)
|
||||
require.Len(t, threads.Threads, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLeaveLastChannel(t *testing.T) {
|
||||
|
@ -212,6 +212,8 @@ channels/db/migrations/mysql/000105_remove_tokens.down.sql
|
||||
channels/db/migrations/mysql/000105_remove_tokens.up.sql
|
||||
channels/db/migrations/mysql/000106_fileinfo_channelid.down.sql
|
||||
channels/db/migrations/mysql/000106_fileinfo_channelid.up.sql
|
||||
channels/db/migrations/mysql/000107_threadmemberships_cleanup.down.sql
|
||||
channels/db/migrations/mysql/000107_threadmemberships_cleanup.up.sql
|
||||
channels/db/migrations/postgres/000001_create_teams.down.sql
|
||||
channels/db/migrations/postgres/000001_create_teams.up.sql
|
||||
channels/db/migrations/postgres/000002_create_team_members.down.sql
|
||||
@ -424,3 +426,5 @@ channels/db/migrations/postgres/000105_remove_tokens.down.sql
|
||||
channels/db/migrations/postgres/000105_remove_tokens.up.sql
|
||||
channels/db/migrations/postgres/000106_fileinfo_channelid.down.sql
|
||||
channels/db/migrations/postgres/000106_fileinfo_channelid.up.sql
|
||||
channels/db/migrations/postgres/000107_threadmemberships_cleanup.down.sql
|
||||
channels/db/migrations/postgres/000107_threadmemberships_cleanup.up.sql
|
||||
|
@ -0,0 +1 @@
|
||||
-- Skipping it because the forward migrations are destructive
|
@ -0,0 +1,5 @@
|
||||
DELETE FROM
|
||||
tm USING ThreadMemberships AS tm
|
||||
JOIN Threads ON Threads.PostId = tm.PostId
|
||||
WHERE
|
||||
(tm.UserId, Threads.ChannelId) NOT IN (SELECT UserId, ChannelId FROM ChannelMembers);
|
@ -0,0 +1 @@
|
||||
-- Skipping it because the forward migrations are destructive
|
@ -0,0 +1,12 @@
|
||||
DELETE FROM threadmemberships WHERE (postid, userid) IN (
|
||||
SELECT
|
||||
threadmemberships.postid,
|
||||
threadmemberships.userid
|
||||
FROM
|
||||
threadmemberships
|
||||
JOIN threads ON threads.postid = threadmemberships.postid
|
||||
LEFT JOIN channelmembers ON channelmembers.userid = threadmemberships.userid
|
||||
AND threads.channelid = channelmembers.channelid
|
||||
WHERE
|
||||
channelmembers.channelid IS NULL
|
||||
);
|
@ -10123,6 +10123,24 @@ func (s *OpenTracingLayerThreadStore) DeleteMembershipForUser(userId string, pos
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerThreadStore) DeleteMembershipsForChannel(userID string, channelID string) error {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.DeleteMembershipsForChannel")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
defer func() {
|
||||
s.Root.Store.SetContext(origCtx)
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
err := s.ThreadStore.DeleteMembershipsForChannel(userID, channelID)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerThreadStore) DeleteOrphanedRows(limit int) (int64, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.DeleteOrphanedRows")
|
||||
|
@ -11563,6 +11563,27 @@ func (s *RetryLayerThreadStore) DeleteMembershipForUser(userId string, postID st
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerThreadStore) DeleteMembershipsForChannel(userID string, channelID string) error {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
err := s.ThreadStore.DeleteMembershipsForChannel(userID, channelID)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !isRepeatableError(err) {
|
||||
return err
|
||||
}
|
||||
tries++
|
||||
if tries >= 3 {
|
||||
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
|
||||
return err
|
||||
}
|
||||
timepkg.Sleep(100 * timepkg.Millisecond)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerThreadStore) DeleteOrphanedRows(limit int) (int64, error) {
|
||||
|
||||
tries := 0
|
||||
|
@ -688,6 +688,28 @@ func (s *SqlThreadStore) UpdateMembership(membership *model.ThreadMembership) (*
|
||||
return s.updateMembership(s.GetMasterX(), membership)
|
||||
}
|
||||
|
||||
func (s *SqlThreadStore) DeleteMembershipsForChannel(userID, channelID string) error {
|
||||
subQuery := s.getSubQueryBuilder().
|
||||
Select("1").
|
||||
From("Threads").
|
||||
Where(sq.And{
|
||||
sq.Expr("Threads.PostId = ThreadMemberships.PostId"),
|
||||
sq.Eq{"Threads.ChannelId": channelID},
|
||||
})
|
||||
|
||||
query := s.getQueryBuilder().
|
||||
Delete("ThreadMemberships").
|
||||
Where(sq.Eq{"UserId": userID}).
|
||||
Where(sq.Expr("EXISTS (?)", subQuery))
|
||||
|
||||
_, err := s.GetMasterX().ExecBuilder(query)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to remove thread memberships with userid=%s channelid=%s", userID, channelID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlThreadStore) updateMembership(ex sqlxExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) {
|
||||
query := s.getQueryBuilder().
|
||||
Update("ThreadMemberships").
|
||||
@ -712,7 +734,14 @@ func (s *SqlThreadStore) GetMembershipsForUser(userId, teamId string) ([]*model.
|
||||
memberships := []*model.ThreadMembership{}
|
||||
|
||||
query := s.getQueryBuilder().
|
||||
Select("ThreadMemberships.*").
|
||||
Select(
|
||||
"ThreadMemberships.PostId",
|
||||
"ThreadMemberships.UserId",
|
||||
"ThreadMemberships.Following",
|
||||
"ThreadMemberships.LastUpdated",
|
||||
"ThreadMemberships.LastViewed",
|
||||
"ThreadMemberships.UnreadMentions",
|
||||
).
|
||||
Join("Threads ON Threads.PostId = ThreadMemberships.PostId").
|
||||
From("ThreadMemberships").
|
||||
Where(sq.Or{sq.Eq{"Threads.ThreadTeamId": teamId}, sq.Eq{"Threads.ThreadTeamId": ""}}).
|
||||
@ -732,7 +761,14 @@ func (s *SqlThreadStore) GetMembershipForUser(userId, postId string) (*model.Thr
|
||||
func (s *SqlThreadStore) getMembershipForUser(ex sqlxExecutor, userId, postId string) (*model.ThreadMembership, error) {
|
||||
var membership model.ThreadMembership
|
||||
query := s.getQueryBuilder().
|
||||
Select("*").
|
||||
Select(
|
||||
"PostId",
|
||||
"UserId",
|
||||
"Following",
|
||||
"LastUpdated",
|
||||
"LastViewed",
|
||||
"UnreadMentions",
|
||||
).
|
||||
From("ThreadMemberships").
|
||||
Where(sq.And{
|
||||
sq.Eq{"PostId": postId},
|
||||
|
@ -344,6 +344,7 @@ type ThreadStore interface {
|
||||
PermanentDeleteBatchThreadMembershipsForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error)
|
||||
DeleteOrphanedRows(limit int) (deleted int64, err error)
|
||||
GetThreadUnreadReplyCount(threadMembership *model.ThreadMembership) (int64, error)
|
||||
DeleteMembershipsForChannel(userID, channelID string) error
|
||||
|
||||
// Insights - threads
|
||||
GetTopThreadsForTeamSince(teamID string, userID string, since int64, offset int, limit int) (*model.TopThreadList, error)
|
||||
|
@ -29,6 +29,20 @@ func (_m *ThreadStore) DeleteMembershipForUser(userId string, postID string) err
|
||||
return r0
|
||||
}
|
||||
|
||||
// DeleteMembershipsForChannel provides a mock function with given fields: userID, channelID
|
||||
func (_m *ThreadStore) DeleteMembershipsForChannel(userID string, channelID string) error {
|
||||
ret := _m.Called(userID, channelID)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, string) error); ok {
|
||||
r0 = rf(userID, channelID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DeleteOrphanedRows provides a mock function with given fields: limit
|
||||
func (_m *ThreadStore) DeleteOrphanedRows(limit int) (int64, error) {
|
||||
ret := _m.Called(limit)
|
||||
|
@ -29,6 +29,7 @@ func TestThreadStore(t *testing.T, ss store.Store, s SqlStore) {
|
||||
t.Run("MarkAllAsReadByChannels", func(t *testing.T) { testMarkAllAsReadByChannels(t, ss) })
|
||||
t.Run("GetTopThreads", func(t *testing.T) { testGetTopThreads(t, ss) })
|
||||
t.Run("MarkAllAsReadByTeam", func(t *testing.T) { testMarkAllAsReadByTeam(t, ss) })
|
||||
t.Run("DeleteMembershipsForChannel", func(t *testing.T) { testDeleteMembershipsForChannel(t, ss) })
|
||||
}
|
||||
|
||||
func testThreadStorePopulation(t *testing.T, ss store.Store) {
|
||||
@ -1914,3 +1915,121 @@ func testMarkAllAsReadByTeam(t *testing.T, ss store.Store) {
|
||||
assertThreadReplyCount(t, userBID, team2.Id, 1, "expected 1 unread message in team2 for userB")
|
||||
})
|
||||
}
|
||||
|
||||
func testDeleteMembershipsForChannel(t *testing.T, ss store.Store) {
|
||||
createThreadMembership := func(userID, postID string) (*model.ThreadMembership, func()) {
|
||||
t.Helper()
|
||||
opts := store.ThreadMembershipOpts{
|
||||
Following: true,
|
||||
IncrementMentions: false,
|
||||
UpdateFollowing: true,
|
||||
UpdateViewedTimestamp: false,
|
||||
UpdateParticipants: false,
|
||||
}
|
||||
mem, err := ss.Thread().MaintainMembership(userID, postID, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
return mem, func() {
|
||||
err := ss.Thread().DeleteMembershipForUser(userID, postID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
postingUserID := model.NewId()
|
||||
userAID := model.NewId()
|
||||
userBID := model.NewId()
|
||||
|
||||
team, err := ss.Team().Save(&model.Team{
|
||||
DisplayName: "DisplayName",
|
||||
Name: "team" + model.NewId(),
|
||||
Email: MakeEmail(),
|
||||
Type: model.TeamOpen,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
channel1, err := ss.Channel().Save(&model.Channel{
|
||||
TeamId: team.Id,
|
||||
DisplayName: "DisplayName",
|
||||
Name: "channel1" + model.NewId(),
|
||||
Type: model.ChannelTypeOpen,
|
||||
}, -1)
|
||||
require.NoError(t, err)
|
||||
channel2, err := ss.Channel().Save(&model.Channel{
|
||||
TeamId: team.Id,
|
||||
DisplayName: "DisplayName2",
|
||||
Name: "channel2" + model.NewId(),
|
||||
Type: model.ChannelTypeOpen,
|
||||
}, -1)
|
||||
require.NoError(t, err)
|
||||
|
||||
rootPost1, err := ss.Post().Save(&model.Post{
|
||||
ChannelId: channel1.Id,
|
||||
UserId: postingUserID,
|
||||
Message: model.NewRandomString(10),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ss.Post().Save(&model.Post{
|
||||
ChannelId: channel1.Id,
|
||||
UserId: postingUserID,
|
||||
Message: model.NewRandomString(10),
|
||||
RootId: rootPost1.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rootPost2, err := ss.Post().Save(&model.Post{
|
||||
ChannelId: channel2.Id,
|
||||
UserId: postingUserID,
|
||||
Message: model.NewRandomString(10),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ss.Post().Save(&model.Post{
|
||||
ChannelId: channel2.Id,
|
||||
UserId: postingUserID,
|
||||
Message: model.NewRandomString(10),
|
||||
RootId: rootPost2.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return memberships for user", func(t *testing.T) {
|
||||
memA1, cleanupA1 := createThreadMembership(userAID, rootPost1.Id)
|
||||
defer cleanupA1()
|
||||
memA2, cleanupA2 := createThreadMembership(userAID, rootPost2.Id)
|
||||
defer cleanupA2()
|
||||
|
||||
membershipsA, err := ss.Thread().GetMembershipsForUser(userAID, team.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, membershipsA, 2)
|
||||
require.ElementsMatch(t, []*model.ThreadMembership{memA1, memA2}, membershipsA)
|
||||
})
|
||||
|
||||
t.Run("should delete memberships for user for channel", func(t *testing.T) {
|
||||
_, cleanupA1 := createThreadMembership(userAID, rootPost1.Id)
|
||||
defer cleanupA1()
|
||||
memA2, cleanupA2 := createThreadMembership(userAID, rootPost2.Id)
|
||||
defer cleanupA2()
|
||||
|
||||
ss.Thread().DeleteMembershipsForChannel(userAID, channel1.Id)
|
||||
membershipsA, err := ss.Thread().GetMembershipsForUser(userAID, team.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, membershipsA, 1)
|
||||
require.ElementsMatch(t, []*model.ThreadMembership{memA2}, membershipsA)
|
||||
})
|
||||
|
||||
t.Run("deleting memberships for channel for userA should not affect userB", func(t *testing.T) {
|
||||
_, cleanupA1 := createThreadMembership(userAID, rootPost1.Id)
|
||||
defer cleanupA1()
|
||||
_, cleanupA2 := createThreadMembership(userAID, rootPost2.Id)
|
||||
defer cleanupA2()
|
||||
memB1, cleanupB2 := createThreadMembership(userBID, rootPost1.Id)
|
||||
defer cleanupB2()
|
||||
|
||||
membershipsB, err := ss.Thread().GetMembershipsForUser(userBID, team.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, membershipsB, 1)
|
||||
require.ElementsMatch(t, []*model.ThreadMembership{memB1}, membershipsB)
|
||||
})
|
||||
}
|
||||
|
@ -9112,6 +9112,22 @@ func (s *TimerLayerThreadStore) DeleteMembershipForUser(userId string, postID st
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *TimerLayerThreadStore) DeleteMembershipsForChannel(userID string, channelID string) error {
|
||||
start := time.Now()
|
||||
|
||||
err := s.ThreadStore.DeleteMembershipsForChannel(userID, channelID)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
success := "false"
|
||||
if err == nil {
|
||||
success = "true"
|
||||
}
|
||||
s.Root.Metrics.ObserveStoreMethodDuration("ThreadStore.DeleteMembershipsForChannel", success, elapsed)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *TimerLayerThreadStore) DeleteOrphanedRows(limit int) (int64, error) {
|
||||
start := time.Now()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user