diff --git a/api4/user_test.go b/api4/user_test.go index 75af85c9e8..40076927e5 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -5515,6 +5515,51 @@ func TestGetThreadsForUser(t *testing.T) { require.NotNil(t, uss3.Threads) require.Len(t, uss3.Threads, 0) }) + + t.Run("editing or reacting to reply post does not make thread unread", func(t *testing.T) { + Client := th.Client + + rootPost, _ := postAndCheck(t, Client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "root post"}) + replyPost, _ := postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel.Id, Message: "reply post", RootId: rootPost.Id}) + uss, resp := th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ + Deleted: false, + }) + CheckNoError(t, resp) + require.Equal(t, uss.TotalUnreadThreads, int64(1)) + require.Equal(t, uss.Threads[0].PostId, rootPost.Id) + + _, resp = th.Client.UpdateThreadReadForUser(th.BasicUser.Id, th.BasicChannel.TeamId, rootPost.Id, model.GetMillis()) + CheckNoError(t, resp) + uss, resp = th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ + Deleted: false, + }) + CheckNoError(t, resp) + require.Equal(t, uss.TotalUnreadThreads, int64(0)) + + // edit post + editedReplyPostMessage := "edited " + replyPost.Message + _, resp = th.SystemAdminClient.PatchPost(replyPost.Id, &model.PostPatch{Message: &editedReplyPostMessage}) + CheckNoError(t, resp) + uss, resp = th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ + Deleted: false, + }) + CheckNoError(t, resp) + require.Equal(t, uss.TotalUnreadThreads, int64(0)) + + // react to post + reaction := &model.Reaction{ + UserId: th.SystemAdminUser.Id, + PostId: replyPost.Id, + EmojiName: "smile", + } + _, resp = th.SystemAdminClient.SaveReaction(reaction) + CheckNoError(t, resp) + uss, resp = th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ + Deleted: false, + }) + CheckNoError(t, resp) + require.Equal(t, uss.TotalUnreadThreads, int64(0)) + }) } func TestThreadSocketEvents(t *testing.T) { diff --git a/store/sqlstore/thread_store.go b/store/sqlstore/thread_store.go index 7e57b5f4a7..66f0336ee7 100644 --- a/store/sqlstore/thread_store.go +++ b/store/sqlstore/thread_store.go @@ -121,7 +121,7 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get model.Post } - unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.UpdateAt >= ThreadMemberships.LastViewed" + unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.CreateAt >= ThreadMemberships.LastViewed" fetchConditions := sq.And{ sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, sq.Eq{"ThreadMemberships.UserId": userId}, @@ -147,7 +147,7 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get LeftJoin("ThreadMemberships ON Posts.RootId = ThreadMemberships.PostId"). LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). Where(fetchConditions). - Where("Posts.UpdateAt >= ThreadMemberships.LastViewed").ToSql() + Where("Posts.CreateAt >= ThreadMemberships.LastViewed").ToSql() totalUnreadThreads, err := s.GetMaster().SelectInt(repliesQuery, repliesQueryArgs...) totalUnreadThreadsChan <- store.StoreResult{Data: totalUnreadThreads, NErr: errors.Wrapf(err, "failed to get count unread on threads for user id=%s", userId)} @@ -334,7 +334,7 @@ func (s *SqlThreadStore) GetThreadForUser(userId, teamId, threadId string, exten model.Post } - unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.UpdateAt >= ThreadMemberships.LastViewed AND Posts.DeleteAt=0" + unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.CreateAt >= ThreadMemberships.LastViewed AND Posts.DeleteAt=0" fetchConditions := sq.And{ sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, sq.Eq{"ThreadMemberships.UserId": userId}, diff --git a/store/storetest/thread_store.go b/store/storetest/thread_store.go index 073b102a34..d9e89a89d2 100644 --- a/store/storetest/thread_store.go +++ b/store/storetest/thread_store.go @@ -329,4 +329,29 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) { return m2.LastUpdated < m.LastUpdated }, time.Second, 10*time.Millisecond) }) + + t.Run("Updating post does not make thread unread", func(t *testing.T) { + newPosts := makeSomePosts() + m, err := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, true, false) + require.NoError(t, err) + th, err := ss.Thread().GetThreadForUser(newPosts[0].UserId, "", newPosts[0].Id, false) + require.NoError(t, err) + require.Equal(t, int64(2), th.UnreadReplies) + + m.LastViewed = newPosts[2].UpdateAt + 1 + _, err = ss.Thread().UpdateMembership(m) + require.NoError(t, err) + th, err = ss.Thread().GetThreadForUser(newPosts[0].UserId, "", newPosts[0].Id, false) + require.NoError(t, err) + require.Equal(t, int64(0), th.UnreadReplies) + + editedPost := newPosts[2].Clone() + editedPost.Message = "This is an edited post" + _, err = ss.Post().Update(editedPost, newPosts[2]) + require.NoError(t, err) + + th, err = ss.Thread().GetThreadForUser(newPosts[0].UserId, "", newPosts[0].Id, false) + require.NoError(t, err) + require.Equal(t, int64(0), th.UnreadReplies) + }) }