From 6125b0ca7f175fcd1cec452be3271dfec0528a9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Espino=20Garc=C3=ADa?= Date: Tue, 24 Oct 2023 15:27:30 +0200 Subject: [PATCH] MM-54778 Fix mark as unread on GMs (#24880) * Fix mark as unread on GMs * Don't count own messages in gms when marking as unread * Change argument name * Rename userId --------- Co-authored-by: Mattermost Build --- server/channels/app/post.go | 8 ++-- server/channels/app/post_test.go | 43 ++++++++++++++++++ .../opentracinglayer/opentracinglayer.go | 8 ++-- .../channels/store/retrylayer/retrylayer.go | 8 ++-- .../channels/store/sqlstore/channel_store.go | 14 +++--- server/channels/store/store.go | 4 +- .../channels/store/storetest/channel_store.go | 45 ++++++++++++++----- .../store/storetest/mocks/ChannelStore.go | 26 +++++------ .../channels/store/timerlayer/timerlayer.go | 8 ++-- 9 files changed, 115 insertions(+), 49 deletions(-) diff --git a/server/channels/app/post.go b/server/channels/app/post.go index 85eaa32064..416188e838 100644 --- a/server/channels/app/post.go +++ b/server/channels/app/post.go @@ -1795,7 +1795,7 @@ func (a *App) countThreadMentions(c request.CTX, user *model.User, post *model.P count := 0 - if channel.Type == model.ChannelTypeDirect { + if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { // In a DM channel, every post made by the other user is a mention otherId := channel.GetOtherUserIdForDM(user.Id) for _, p := range posts { @@ -1842,16 +1842,16 @@ func (a *App) countMentionsFromPost(c request.CTX, user *model.User, post *model return 0, 0, 0, err } - if channel.Type == model.ChannelTypeDirect { + if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup { // In a DM channel, every post made by the other user is a mention - count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id)) + count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, user.Id) if nErr != nil { return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr) } var urgentCount int if a.IsPostPriorityEnabled() { - urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id)) + urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, user.Id) if nErr != nil { return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_urgent_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr) } diff --git a/server/channels/app/post_test.go b/server/channels/app/post_test.go index be440579fe..3f76e43e0c 100644 --- a/server/channels/app/post_test.go +++ b/server/channels/app/post_test.go @@ -1946,6 +1946,49 @@ func TestCountMentionsFromPost(t *testing.T) { assert.Equal(t, 0, count) }) + t.Run("should return the number of posts made by the other user for a group message", func(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + user1 := th.BasicUser + user2 := th.BasicUser2 + user3 := th.SystemAdminUser + + channel, err := th.App.createGroupChannel(th.Context, []string{user1.Id, user2.Id, user3.Id}) + require.Nil(t, err) + + post1, err := th.App.CreatePost(th.Context, &model.Post{ + UserId: user1.Id, + ChannelId: channel.Id, + Message: "test", + }, channel, false, true) + require.Nil(t, err) + + _, err = th.App.CreatePost(th.Context, &model.Post{ + UserId: user1.Id, + ChannelId: channel.Id, + Message: "test2", + }, channel, false, true) + require.Nil(t, err) + + _, err = th.App.CreatePost(th.Context, &model.Post{ + UserId: user3.Id, + ChannelId: channel.Id, + Message: "test3", + }, channel, false, true) + require.Nil(t, err) + + count, _, _, err := th.App.countMentionsFromPost(th.Context, user2, post1) + + assert.Nil(t, err) + assert.Equal(t, 3, count) + + count, _, _, err = th.App.countMentionsFromPost(th.Context, user1, post1) + + assert.Nil(t, err) + assert.Equal(t, 1, count) + }) + t.Run("should not count mentions from the before the given post", func(t *testing.T) { th := Setup(t).InitBasic() defer th.TearDown() diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go index 014374d2ed..894e39796a 100644 --- a/server/channels/store/opentracinglayer/opentracinglayer.go +++ b/server/channels/store/opentracinglayer/opentracinglayer.go @@ -757,7 +757,7 @@ func (s *OpenTracingLayerChannelStore) ClearSidebarOnTeamLeave(userID string, te return err } -func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { +func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountPostsAfter") s.Root.Store.SetContext(newCtx) @@ -766,7 +766,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta }() defer span.Finish() - result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) + result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -775,7 +775,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta return result, resultVar1, err } -func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { +func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountUrgentPostsAfter") s.Root.Store.SetContext(newCtx) @@ -784,7 +784,7 @@ func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, t }() defer span.Finish() - result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) + result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 22edcdc277..c755ee93ff 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -808,11 +808,11 @@ func (s *RetryLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s } -func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { +func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) { tries := 0 for { - result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) + result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID) if err == nil { return result, resultVar1, nil } @@ -829,11 +829,11 @@ func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int } -func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { +func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) { tries := 0 for { - result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) + result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID) if err == nil { return result, nil } diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 78341b9602..f54e5348c5 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -2644,7 +2644,7 @@ func (s SqlChannelStore) UpdateLastViewedAt(channelIds []string, userId string) return times, nil } -func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, userId string) (int, error) { +func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, error) { query := s.getQueryBuilder(). Select("count(*)"). From("PostsPriority"). @@ -2656,8 +2656,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64 sq.Eq{"Posts.DeleteAt": 0}, }) - if userId != "" { - query = query.Where(sq.Eq{"Posts.UserId": userId}) + if excludedUserID != "" { + query = query.Where(sq.NotEq{"Posts.UserId": excludedUserID}) } var urgent int64 @@ -2669,8 +2669,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64 return int(urgent), nil } -// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by that user. -func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, userId string) (int, int, error) { +// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by any other user. +func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, int, error) { joinLeavePostTypes := []string{ // These types correspond to the ones checked by Post.IsJoinLeaveMessage model.PostTypeJoinLeave, @@ -2694,8 +2694,8 @@ func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, user sq.Eq{"DeleteAt": 0}, }) - if userId != "" { - query = query.Where(sq.Eq{"UserId": userId}) + if excludedUserID != "" { + query = query.Where(sq.NotEq{"UserId": excludedUserID}) } sql, args, err := query.ToSql() if err != nil { diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 9886c3a2c5..6c06c19485 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -248,8 +248,8 @@ type ChannelStore interface { PermanentDeleteMembersByChannel(channelID string) error UpdateLastViewedAt(channelIds []string, userID string) (map[string]int64, error) UpdateLastViewedAtPost(unreadPost *model.Post, userID string, mentionCount, mentionCountRoot, urgentMentionCount int, setUnreadCountRoot bool) (*model.ChannelUnreadAt, error) - CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) - CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) + CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) + CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) IncrementMentionCount(channelID string, userIDs []string, isRoot, isUrgent bool) error AnalyticsTypeCount(teamID string, channelType model.ChannelType) (int64, error) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) diff --git a/server/channels/store/storetest/channel_store.go b/server/channels/store/storetest/channel_store.go index 49a38237c6..086a4b3a15 100644 --- a/server/channels/store/storetest/channel_store.go +++ b/server/channels/store/storetest/channel_store.go @@ -4479,6 +4479,7 @@ func testCountPostsAfter(t *testing.T, ss store.Store) { t.Run("should count all posts with or without the given user ID", func(t *testing.T) { userId1 := model.NewId() userId2 := model.NewId() + userId3 := model.NewId() channelId := model.NewId() @@ -4503,21 +4504,28 @@ func testCountPostsAfter(t *testing.T, ss store.Store) { }) require.NoError(t, err) + _, err = ss.Post().Save(&model.Post{ + UserId: userId3, + ChannelId: channelId, + CreateAt: 1003, + }) + require.NoError(t, err) + count, _, err := ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, "") require.NoError(t, err) - assert.Equal(t, 3, count) + assert.Equal(t, 4, count) count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, "") require.NoError(t, err) - assert.Equal(t, 2, count) + assert.Equal(t, 3, count) - count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId1) + count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId2) + require.NoError(t, err) + assert.Equal(t, 3, count) + + count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId2) require.NoError(t, err) assert.Equal(t, 2, count) - - count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId1) - require.NoError(t, err) - assert.Equal(t, 1, count) }) t.Run("should not count deleted posts", func(t *testing.T) { @@ -4623,6 +4631,7 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) { t.Run("should count all posts with or without the given user ID", func(t *testing.T) { userId1 := model.NewId() userId2 := model.NewId() + userId3 := model.NewId() channelId := model.NewId() @@ -4661,19 +4670,33 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) { }) require.NoError(t, err) + _, err = ss.Post().Save(&model.Post{ + UserId: userId3, + ChannelId: channelId, + CreateAt: 1003, + Metadata: &model.PostMetadata{ + Priority: &model.PostPriority{ + Priority: model.NewString(model.PostPriorityUrgent), + RequestedAck: model.NewBool(false), + PersistentNotifications: model.NewBool(false), + }, + }, + }) + require.NoError(t, err) + count, err := ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, "") require.NoError(t, err) - assert.Equal(t, 1, count) + assert.Equal(t, 2, count) count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, "") require.NoError(t, err) - assert.Equal(t, 0, count) + assert.Equal(t, 1, count) - count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId1) + count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId3) require.NoError(t, err) assert.Equal(t, 1, count) - count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId1) + count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId3) require.NoError(t, err) assert.Equal(t, 0, count) }) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go index 53bb50726e..e272687ffe 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -184,30 +184,30 @@ func (_m *ChannelStore) ClearSidebarOnTeamLeave(userID string, teamID string) er return r0 } -// CountPostsAfter provides a mock function with given fields: channelID, timestamp, userID -func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { - ret := _m.Called(channelID, timestamp, userID) +// CountPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID +func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) { + ret := _m.Called(channelID, timestamp, excludedUserID) var r0 int var r1 int var r2 error if rf, ok := ret.Get(0).(func(string, int64, string) (int, int, error)); ok { - return rf(channelID, timestamp, userID) + return rf(channelID, timestamp, excludedUserID) } if rf, ok := ret.Get(0).(func(string, int64, string) int); ok { - r0 = rf(channelID, timestamp, userID) + r0 = rf(channelID, timestamp, excludedUserID) } else { r0 = ret.Get(0).(int) } if rf, ok := ret.Get(1).(func(string, int64, string) int); ok { - r1 = rf(channelID, timestamp, userID) + r1 = rf(channelID, timestamp, excludedUserID) } else { r1 = ret.Get(1).(int) } if rf, ok := ret.Get(2).(func(string, int64, string) error); ok { - r2 = rf(channelID, timestamp, userID) + r2 = rf(channelID, timestamp, excludedUserID) } else { r2 = ret.Error(2) } @@ -215,23 +215,23 @@ func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userI return r0, r1, r2 } -// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, userID -func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { - ret := _m.Called(channelID, timestamp, userID) +// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID +func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) { + ret := _m.Called(channelID, timestamp, excludedUserID) var r0 int var r1 error if rf, ok := ret.Get(0).(func(string, int64, string) (int, error)); ok { - return rf(channelID, timestamp, userID) + return rf(channelID, timestamp, excludedUserID) } if rf, ok := ret.Get(0).(func(string, int64, string) int); ok { - r0 = rf(channelID, timestamp, userID) + r0 = rf(channelID, timestamp, excludedUserID) } else { r0 = ret.Get(0).(int) } if rf, ok := ret.Get(1).(func(string, int64, string) error); ok { - r1 = rf(channelID, timestamp, userID) + r1 = rf(channelID, timestamp, excludedUserID) } else { r1 = ret.Error(1) } diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index b59768e2d9..c5975e5cd9 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -731,10 +731,10 @@ func (s *TimerLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s return err } -func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { +func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) { start := time.Now() - result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) + result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -747,10 +747,10 @@ func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int return result, resultVar1, err } -func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { +func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) { start := time.Now() - result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) + result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil {