MM-54778 Fix mark as unread on GMs (#24880)

* Fix mark as unread on GMs

* Don't count own messages in gms when marking as unread

* Change argument name

* Rename userId

---------

Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Daniel Espino García 2023-10-24 15:27:30 +02:00 committed by GitHub
parent 5d3ba7483b
commit 6125b0ca7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 115 additions and 49 deletions

View File

@ -1795,7 +1795,7 @@ func (a *App) countThreadMentions(c request.CTX, user *model.User, post *model.P
count := 0 count := 0
if channel.Type == model.ChannelTypeDirect { if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
// In a DM channel, every post made by the other user is a mention // In a DM channel, every post made by the other user is a mention
otherId := channel.GetOtherUserIdForDM(user.Id) otherId := channel.GetOtherUserIdForDM(user.Id)
for _, p := range posts { for _, p := range posts {
@ -1842,16 +1842,16 @@ func (a *App) countMentionsFromPost(c request.CTX, user *model.User, post *model
return 0, 0, 0, err return 0, 0, 0, err
} }
if channel.Type == model.ChannelTypeDirect { if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
// In a DM channel, every post made by the other user is a mention // In a DM channel, every post made by the other user is a mention
count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id)) count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
if nErr != nil { if nErr != nil {
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr) return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
} }
var urgentCount int var urgentCount int
if a.IsPostPriorityEnabled() { if a.IsPostPriorityEnabled() {
urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id)) urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
if nErr != nil { if nErr != nil {
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_urgent_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr) return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_urgent_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
} }

View File

@ -1946,6 +1946,49 @@ func TestCountMentionsFromPost(t *testing.T) {
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
}) })
t.Run("should return the number of posts made by the other user for a group message", func(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
user1 := th.BasicUser
user2 := th.BasicUser2
user3 := th.SystemAdminUser
channel, err := th.App.createGroupChannel(th.Context, []string{user1.Id, user2.Id, user3.Id})
require.Nil(t, err)
post1, err := th.App.CreatePost(th.Context, &model.Post{
UserId: user1.Id,
ChannelId: channel.Id,
Message: "test",
}, channel, false, true)
require.Nil(t, err)
_, err = th.App.CreatePost(th.Context, &model.Post{
UserId: user1.Id,
ChannelId: channel.Id,
Message: "test2",
}, channel, false, true)
require.Nil(t, err)
_, err = th.App.CreatePost(th.Context, &model.Post{
UserId: user3.Id,
ChannelId: channel.Id,
Message: "test3",
}, channel, false, true)
require.Nil(t, err)
count, _, _, err := th.App.countMentionsFromPost(th.Context, user2, post1)
assert.Nil(t, err)
assert.Equal(t, 3, count)
count, _, _, err = th.App.countMentionsFromPost(th.Context, user1, post1)
assert.Nil(t, err)
assert.Equal(t, 1, count)
})
t.Run("should not count mentions from the before the given post", func(t *testing.T) { t.Run("should not count mentions from the before the given post", func(t *testing.T) {
th := Setup(t).InitBasic() th := Setup(t).InitBasic()
defer th.TearDown() defer th.TearDown()

View File

@ -757,7 +757,7 @@ func (s *OpenTracingLayerChannelStore) ClearSidebarOnTeamLeave(userID string, te
return err return err
} }
func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
origCtx := s.Root.Store.Context() origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountPostsAfter") span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountPostsAfter")
s.Root.Store.SetContext(newCtx) s.Root.Store.SetContext(newCtx)
@ -766,7 +766,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
}() }()
defer span.Finish() defer span.Finish()
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
if err != nil { if err != nil {
span.LogFields(spanlog.Error(err)) span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true) ext.Error.Set(span, true)
@ -775,7 +775,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
return result, resultVar1, err return result, resultVar1, err
} }
func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
origCtx := s.Root.Store.Context() origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountUrgentPostsAfter") span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountUrgentPostsAfter")
s.Root.Store.SetContext(newCtx) s.Root.Store.SetContext(newCtx)
@ -784,7 +784,7 @@ func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, t
}() }()
defer span.Finish() defer span.Finish()
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
if err != nil { if err != nil {
span.LogFields(spanlog.Error(err)) span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true) ext.Error.Set(span, true)

View File

@ -808,11 +808,11 @@ func (s *RetryLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
} }
func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
tries := 0 tries := 0
for { for {
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
if err == nil { if err == nil {
return result, resultVar1, nil return result, resultVar1, nil
} }
@ -829,11 +829,11 @@ func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int
} }
func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
tries := 0 tries := 0
for { for {
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
if err == nil { if err == nil {
return result, nil return result, nil
} }

View File

@ -2644,7 +2644,7 @@ func (s SqlChannelStore) UpdateLastViewedAt(channelIds []string, userId string)
return times, nil return times, nil
} }
func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, userId string) (int, error) { func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, error) {
query := s.getQueryBuilder(). query := s.getQueryBuilder().
Select("count(*)"). Select("count(*)").
From("PostsPriority"). From("PostsPriority").
@ -2656,8 +2656,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
sq.Eq{"Posts.DeleteAt": 0}, sq.Eq{"Posts.DeleteAt": 0},
}) })
if userId != "" { if excludedUserID != "" {
query = query.Where(sq.Eq{"Posts.UserId": userId}) query = query.Where(sq.NotEq{"Posts.UserId": excludedUserID})
} }
var urgent int64 var urgent int64
@ -2669,8 +2669,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
return int(urgent), nil return int(urgent), nil
} }
// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by that user. // CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by any other user.
func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, userId string) (int, int, error) { func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, int, error) {
joinLeavePostTypes := []string{ joinLeavePostTypes := []string{
// These types correspond to the ones checked by Post.IsJoinLeaveMessage // These types correspond to the ones checked by Post.IsJoinLeaveMessage
model.PostTypeJoinLeave, model.PostTypeJoinLeave,
@ -2694,8 +2694,8 @@ func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, user
sq.Eq{"DeleteAt": 0}, sq.Eq{"DeleteAt": 0},
}) })
if userId != "" { if excludedUserID != "" {
query = query.Where(sq.Eq{"UserId": userId}) query = query.Where(sq.NotEq{"UserId": excludedUserID})
} }
sql, args, err := query.ToSql() sql, args, err := query.ToSql()
if err != nil { if err != nil {

View File

@ -248,8 +248,8 @@ type ChannelStore interface {
PermanentDeleteMembersByChannel(channelID string) error PermanentDeleteMembersByChannel(channelID string) error
UpdateLastViewedAt(channelIds []string, userID string) (map[string]int64, error) UpdateLastViewedAt(channelIds []string, userID string) (map[string]int64, error)
UpdateLastViewedAtPost(unreadPost *model.Post, userID string, mentionCount, mentionCountRoot, urgentMentionCount int, setUnreadCountRoot bool) (*model.ChannelUnreadAt, error) UpdateLastViewedAtPost(unreadPost *model.Post, userID string, mentionCount, mentionCountRoot, urgentMentionCount int, setUnreadCountRoot bool) (*model.ChannelUnreadAt, error)
CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error)
CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error)
IncrementMentionCount(channelID string, userIDs []string, isRoot, isUrgent bool) error IncrementMentionCount(channelID string, userIDs []string, isRoot, isUrgent bool) error
AnalyticsTypeCount(teamID string, channelType model.ChannelType) (int64, error) AnalyticsTypeCount(teamID string, channelType model.ChannelType) (int64, error)
GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error)

View File

@ -4479,6 +4479,7 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
t.Run("should count all posts with or without the given user ID", func(t *testing.T) { t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
userId1 := model.NewId() userId1 := model.NewId()
userId2 := model.NewId() userId2 := model.NewId()
userId3 := model.NewId()
channelId := model.NewId() channelId := model.NewId()
@ -4503,21 +4504,28 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ss.Post().Save(&model.Post{
UserId: userId3,
ChannelId: channelId,
CreateAt: 1003,
})
require.NoError(t, err)
count, _, err := ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, "") count, _, err := ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 3, count) assert.Equal(t, 4, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, "") count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, count) assert.Equal(t, 3, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId1) count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId2)
require.NoError(t, err)
assert.Equal(t, 3, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId2)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, count) assert.Equal(t, 2, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId1)
require.NoError(t, err)
assert.Equal(t, 1, count)
}) })
t.Run("should not count deleted posts", func(t *testing.T) { t.Run("should not count deleted posts", func(t *testing.T) {
@ -4623,6 +4631,7 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
t.Run("should count all posts with or without the given user ID", func(t *testing.T) { t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
userId1 := model.NewId() userId1 := model.NewId()
userId2 := model.NewId() userId2 := model.NewId()
userId3 := model.NewId()
channelId := model.NewId() channelId := model.NewId()
@ -4661,19 +4670,33 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ss.Post().Save(&model.Post{
UserId: userId3,
ChannelId: channelId,
CreateAt: 1003,
Metadata: &model.PostMetadata{
Priority: &model.PostPriority{
Priority: model.NewString(model.PostPriorityUrgent),
RequestedAck: model.NewBool(false),
PersistentNotifications: model.NewBool(false),
},
},
})
require.NoError(t, err)
count, err := ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, "") count, err := ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, count) assert.Equal(t, 2, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, "") count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, count) assert.Equal(t, 1, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId1) count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId3)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, count) assert.Equal(t, 1, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId1) count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId3)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
}) })

View File

@ -184,30 +184,30 @@ func (_m *ChannelStore) ClearSidebarOnTeamLeave(userID string, teamID string) er
return r0 return r0
} }
// CountPostsAfter provides a mock function with given fields: channelID, timestamp, userID // CountPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
ret := _m.Called(channelID, timestamp, userID) ret := _m.Called(channelID, timestamp, excludedUserID)
var r0 int var r0 int
var r1 int var r1 int
var r2 error var r2 error
if rf, ok := ret.Get(0).(func(string, int64, string) (int, int, error)); ok { if rf, ok := ret.Get(0).(func(string, int64, string) (int, int, error)); ok {
return rf(channelID, timestamp, userID) return rf(channelID, timestamp, excludedUserID)
} }
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok { if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
r0 = rf(channelID, timestamp, userID) r0 = rf(channelID, timestamp, excludedUserID)
} else { } else {
r0 = ret.Get(0).(int) r0 = ret.Get(0).(int)
} }
if rf, ok := ret.Get(1).(func(string, int64, string) int); ok { if rf, ok := ret.Get(1).(func(string, int64, string) int); ok {
r1 = rf(channelID, timestamp, userID) r1 = rf(channelID, timestamp, excludedUserID)
} else { } else {
r1 = ret.Get(1).(int) r1 = ret.Get(1).(int)
} }
if rf, ok := ret.Get(2).(func(string, int64, string) error); ok { if rf, ok := ret.Get(2).(func(string, int64, string) error); ok {
r2 = rf(channelID, timestamp, userID) r2 = rf(channelID, timestamp, excludedUserID)
} else { } else {
r2 = ret.Error(2) r2 = ret.Error(2)
} }
@ -215,23 +215,23 @@ func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userI
return r0, r1, r2 return r0, r1, r2
} }
// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, userID // CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
ret := _m.Called(channelID, timestamp, userID) ret := _m.Called(channelID, timestamp, excludedUserID)
var r0 int var r0 int
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(string, int64, string) (int, error)); ok { if rf, ok := ret.Get(0).(func(string, int64, string) (int, error)); ok {
return rf(channelID, timestamp, userID) return rf(channelID, timestamp, excludedUserID)
} }
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok { if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
r0 = rf(channelID, timestamp, userID) r0 = rf(channelID, timestamp, excludedUserID)
} else { } else {
r0 = ret.Get(0).(int) r0 = ret.Get(0).(int)
} }
if rf, ok := ret.Get(1).(func(string, int64, string) error); ok { if rf, ok := ret.Get(1).(func(string, int64, string) error); ok {
r1 = rf(channelID, timestamp, userID) r1 = rf(channelID, timestamp, excludedUserID)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@ -731,10 +731,10 @@ func (s *TimerLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
return err return err
} }
func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) { func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
start := time.Now() start := time.Now()
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID) result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
elapsed := float64(time.Since(start)) / float64(time.Second) elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil { if s.Root.Metrics != nil {
@ -747,10 +747,10 @@ func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int
return result, resultVar1, err return result, resultVar1, err
} }
func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) { func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
start := time.Now() start := time.Now()
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID) result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
elapsed := float64(time.Since(start)) / float64(time.Second) elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil { if s.Root.Metrics != nil {