MM-54778 Fix mark as unread on GMs (#24880)

* Fix mark as unread on GMs

* Don't count own messages in gms when marking as unread

* Change argument name

* Rename userId

---------

Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Daniel Espino García 2023-10-24 15:27:30 +02:00 committed by GitHub
parent 5d3ba7483b
commit 6125b0ca7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 115 additions and 49 deletions

View File

@ -1795,7 +1795,7 @@ func (a *App) countThreadMentions(c request.CTX, user *model.User, post *model.P
count := 0
if channel.Type == model.ChannelTypeDirect {
if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
// In a DM channel, every post made by the other user is a mention
otherId := channel.GetOtherUserIdForDM(user.Id)
for _, p := range posts {
@ -1842,16 +1842,16 @@ func (a *App) countMentionsFromPost(c request.CTX, user *model.User, post *model
return 0, 0, 0, err
}
if channel.Type == model.ChannelTypeDirect {
if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
// In a DM channel, every post made by the other user is a mention
count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id))
count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
if nErr != nil {
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
}
var urgentCount int
if a.IsPostPriorityEnabled() {
urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id))
urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
if nErr != nil {
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_urgent_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
}

View File

@ -1946,6 +1946,49 @@ func TestCountMentionsFromPost(t *testing.T) {
assert.Equal(t, 0, count)
})
t.Run("should return the number of posts made by the other user for a group message", func(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
user1 := th.BasicUser
user2 := th.BasicUser2
user3 := th.SystemAdminUser
channel, err := th.App.createGroupChannel(th.Context, []string{user1.Id, user2.Id, user3.Id})
require.Nil(t, err)
post1, err := th.App.CreatePost(th.Context, &model.Post{
UserId: user1.Id,
ChannelId: channel.Id,
Message: "test",
}, channel, false, true)
require.Nil(t, err)
_, err = th.App.CreatePost(th.Context, &model.Post{
UserId: user1.Id,
ChannelId: channel.Id,
Message: "test2",
}, channel, false, true)
require.Nil(t, err)
_, err = th.App.CreatePost(th.Context, &model.Post{
UserId: user3.Id,
ChannelId: channel.Id,
Message: "test3",
}, channel, false, true)
require.Nil(t, err)
count, _, _, err := th.App.countMentionsFromPost(th.Context, user2, post1)
assert.Nil(t, err)
assert.Equal(t, 3, count)
count, _, _, err = th.App.countMentionsFromPost(th.Context, user1, post1)
assert.Nil(t, err)
assert.Equal(t, 1, count)
})
t.Run("should not count mentions from the before the given post", func(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()

View File

@ -757,7 +757,7 @@ func (s *OpenTracingLayerChannelStore) ClearSidebarOnTeamLeave(userID string, te
return err
}
func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountPostsAfter")
s.Root.Store.SetContext(newCtx)
@ -766,7 +766,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
}()
defer span.Finish()
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
@ -775,7 +775,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
return result, resultVar1, err
}
func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountUrgentPostsAfter")
s.Root.Store.SetContext(newCtx)
@ -784,7 +784,7 @@ func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, t
}()
defer span.Finish()
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)

View File

@ -808,11 +808,11 @@ func (s *RetryLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
}
func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
tries := 0
for {
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
if err == nil {
return result, resultVar1, nil
}
@ -829,11 +829,11 @@ func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int
}
func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
tries := 0
for {
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
if err == nil {
return result, nil
}

View File

@ -2644,7 +2644,7 @@ func (s SqlChannelStore) UpdateLastViewedAt(channelIds []string, userId string)
return times, nil
}
func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, userId string) (int, error) {
func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, error) {
query := s.getQueryBuilder().
Select("count(*)").
From("PostsPriority").
@ -2656,8 +2656,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
sq.Eq{"Posts.DeleteAt": 0},
})
if userId != "" {
query = query.Where(sq.Eq{"Posts.UserId": userId})
if excludedUserID != "" {
query = query.Where(sq.NotEq{"Posts.UserId": excludedUserID})
}
var urgent int64
@ -2669,8 +2669,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
return int(urgent), nil
}
// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by that user.
func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, userId string) (int, int, error) {
// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by any other user.
func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, int, error) {
joinLeavePostTypes := []string{
// These types correspond to the ones checked by Post.IsJoinLeaveMessage
model.PostTypeJoinLeave,
@ -2694,8 +2694,8 @@ func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, user
sq.Eq{"DeleteAt": 0},
})
if userId != "" {
query = query.Where(sq.Eq{"UserId": userId})
if excludedUserID != "" {
query = query.Where(sq.NotEq{"UserId": excludedUserID})
}
sql, args, err := query.ToSql()
if err != nil {

View File

@ -248,8 +248,8 @@ type ChannelStore interface {
PermanentDeleteMembersByChannel(channelID string) error
UpdateLastViewedAt(channelIds []string, userID string) (map[string]int64, error)
UpdateLastViewedAtPost(unreadPost *model.Post, userID string, mentionCount, mentionCountRoot, urgentMentionCount int, setUnreadCountRoot bool) (*model.ChannelUnreadAt, error)
CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error)
CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error)
CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error)
CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error)
IncrementMentionCount(channelID string, userIDs []string, isRoot, isUrgent bool) error
AnalyticsTypeCount(teamID string, channelType model.ChannelType) (int64, error)
GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error)

View File

@ -4479,6 +4479,7 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
userId1 := model.NewId()
userId2 := model.NewId()
userId3 := model.NewId()
channelId := model.NewId()
@ -4503,21 +4504,28 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
})
require.NoError(t, err)
_, err = ss.Post().Save(&model.Post{
UserId: userId3,
ChannelId: channelId,
CreateAt: 1003,
})
require.NoError(t, err)
count, _, err := ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, "")
require.NoError(t, err)
assert.Equal(t, 3, count)
assert.Equal(t, 4, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, "")
require.NoError(t, err)
assert.Equal(t, 2, count)
assert.Equal(t, 3, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId1)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId2)
require.NoError(t, err)
assert.Equal(t, 3, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId2)
require.NoError(t, err)
assert.Equal(t, 2, count)
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId1)
require.NoError(t, err)
assert.Equal(t, 1, count)
})
t.Run("should not count deleted posts", func(t *testing.T) {
@ -4623,6 +4631,7 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
userId1 := model.NewId()
userId2 := model.NewId()
userId3 := model.NewId()
channelId := model.NewId()
@ -4661,19 +4670,33 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
})
require.NoError(t, err)
_, err = ss.Post().Save(&model.Post{
UserId: userId3,
ChannelId: channelId,
CreateAt: 1003,
Metadata: &model.PostMetadata{
Priority: &model.PostPriority{
Priority: model.NewString(model.PostPriorityUrgent),
RequestedAck: model.NewBool(false),
PersistentNotifications: model.NewBool(false),
},
},
})
require.NoError(t, err)
count, err := ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, "")
require.NoError(t, err)
assert.Equal(t, 1, count)
assert.Equal(t, 2, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, "")
require.NoError(t, err)
assert.Equal(t, 0, count)
assert.Equal(t, 1, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId1)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId3)
require.NoError(t, err)
assert.Equal(t, 1, count)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId1)
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId3)
require.NoError(t, err)
assert.Equal(t, 0, count)
})

View File

@ -184,30 +184,30 @@ func (_m *ChannelStore) ClearSidebarOnTeamLeave(userID string, teamID string) er
return r0
}
// CountPostsAfter provides a mock function with given fields: channelID, timestamp, userID
func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
ret := _m.Called(channelID, timestamp, userID)
// CountPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
ret := _m.Called(channelID, timestamp, excludedUserID)
var r0 int
var r1 int
var r2 error
if rf, ok := ret.Get(0).(func(string, int64, string) (int, int, error)); ok {
return rf(channelID, timestamp, userID)
return rf(channelID, timestamp, excludedUserID)
}
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
r0 = rf(channelID, timestamp, userID)
r0 = rf(channelID, timestamp, excludedUserID)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func(string, int64, string) int); ok {
r1 = rf(channelID, timestamp, userID)
r1 = rf(channelID, timestamp, excludedUserID)
} else {
r1 = ret.Get(1).(int)
}
if rf, ok := ret.Get(2).(func(string, int64, string) error); ok {
r2 = rf(channelID, timestamp, userID)
r2 = rf(channelID, timestamp, excludedUserID)
} else {
r2 = ret.Error(2)
}
@ -215,23 +215,23 @@ func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userI
return r0, r1, r2
}
// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, userID
func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
ret := _m.Called(channelID, timestamp, userID)
// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
ret := _m.Called(channelID, timestamp, excludedUserID)
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func(string, int64, string) (int, error)); ok {
return rf(channelID, timestamp, userID)
return rf(channelID, timestamp, excludedUserID)
}
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
r0 = rf(channelID, timestamp, userID)
r0 = rf(channelID, timestamp, excludedUserID)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func(string, int64, string) error); ok {
r1 = rf(channelID, timestamp, userID)
r1 = rf(channelID, timestamp, excludedUserID)
} else {
r1 = ret.Error(1)
}

View File

@ -731,10 +731,10 @@ func (s *TimerLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
return err
}
func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
start := time.Now()
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@ -747,10 +747,10 @@ func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int
return result, resultVar1, err
}
func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
start := time.Now()
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {