mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-54778 Fix mark as unread on GMs (#24880)
* Fix mark as unread on GMs * Don't count own messages in gms when marking as unread * Change argument name * Rename userId --------- Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
parent
5d3ba7483b
commit
6125b0ca7f
@ -1795,7 +1795,7 @@ func (a *App) countThreadMentions(c request.CTX, user *model.User, post *model.P
|
||||
|
||||
count := 0
|
||||
|
||||
if channel.Type == model.ChannelTypeDirect {
|
||||
if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
|
||||
// In a DM channel, every post made by the other user is a mention
|
||||
otherId := channel.GetOtherUserIdForDM(user.Id)
|
||||
for _, p := range posts {
|
||||
@ -1842,16 +1842,16 @@ func (a *App) countMentionsFromPost(c request.CTX, user *model.User, post *model
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
|
||||
if channel.Type == model.ChannelTypeDirect {
|
||||
if channel.Type == model.ChannelTypeDirect || channel.Type == model.ChannelTypeGroup {
|
||||
// In a DM channel, every post made by the other user is a mention
|
||||
count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id))
|
||||
count, countRoot, nErr := a.Srv().Store().Channel().CountPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
|
||||
if nErr != nil {
|
||||
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
|
||||
}
|
||||
|
||||
var urgentCount int
|
||||
if a.IsPostPriorityEnabled() {
|
||||
urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, channel.GetOtherUserIdForDM(user.Id))
|
||||
urgentCount, nErr = a.Srv().Store().Channel().CountUrgentPostsAfter(post.ChannelId, post.CreateAt-1, user.Id)
|
||||
if nErr != nil {
|
||||
return 0, 0, 0, model.NewAppError("countMentionsFromPost", "app.channel.count_urgent_posts_since.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
|
||||
}
|
||||
|
@ -1946,6 +1946,49 @@ func TestCountMentionsFromPost(t *testing.T) {
|
||||
assert.Equal(t, 0, count)
|
||||
})
|
||||
|
||||
t.Run("should return the number of posts made by the other user for a group message", func(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
user1 := th.BasicUser
|
||||
user2 := th.BasicUser2
|
||||
user3 := th.SystemAdminUser
|
||||
|
||||
channel, err := th.App.createGroupChannel(th.Context, []string{user1.Id, user2.Id, user3.Id})
|
||||
require.Nil(t, err)
|
||||
|
||||
post1, err := th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: user1.Id,
|
||||
ChannelId: channel.Id,
|
||||
Message: "test",
|
||||
}, channel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: user1.Id,
|
||||
ChannelId: channel.Id,
|
||||
Message: "test2",
|
||||
}, channel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = th.App.CreatePost(th.Context, &model.Post{
|
||||
UserId: user3.Id,
|
||||
ChannelId: channel.Id,
|
||||
Message: "test3",
|
||||
}, channel, false, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
count, _, _, err := th.App.countMentionsFromPost(th.Context, user2, post1)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
count, _, _, err = th.App.countMentionsFromPost(th.Context, user1, post1)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("should not count mentions from the before the given post", func(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
@ -757,7 +757,7 @@ func (s *OpenTracingLayerChannelStore) ClearSidebarOnTeamLeave(userID string, te
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
|
||||
func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountPostsAfter")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
@ -766,7 +766,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
@ -775,7 +775,7 @@ func (s *OpenTracingLayerChannelStore) CountPostsAfter(channelID string, timesta
|
||||
return result, resultVar1, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
|
||||
func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.CountUrgentPostsAfter")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
@ -784,7 +784,7 @@ func (s *OpenTracingLayerChannelStore) CountUrgentPostsAfter(channelID string, t
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
|
@ -808,11 +808,11 @@ func (s *RetryLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
|
||||
func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
|
||||
if err == nil {
|
||||
return result, resultVar1, nil
|
||||
}
|
||||
@ -829,11 +829,11 @@ func (s *RetryLayerChannelStore) CountPostsAfter(channelID string, timestamp int
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
|
||||
func (s *RetryLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
@ -2644,7 +2644,7 @@ func (s SqlChannelStore) UpdateLastViewedAt(channelIds []string, userId string)
|
||||
return times, nil
|
||||
}
|
||||
|
||||
func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, userId string) (int, error) {
|
||||
func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, error) {
|
||||
query := s.getQueryBuilder().
|
||||
Select("count(*)").
|
||||
From("PostsPriority").
|
||||
@ -2656,8 +2656,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
|
||||
sq.Eq{"Posts.DeleteAt": 0},
|
||||
})
|
||||
|
||||
if userId != "" {
|
||||
query = query.Where(sq.Eq{"Posts.UserId": userId})
|
||||
if excludedUserID != "" {
|
||||
query = query.Where(sq.NotEq{"Posts.UserId": excludedUserID})
|
||||
}
|
||||
|
||||
var urgent int64
|
||||
@ -2669,8 +2669,8 @@ func (s SqlChannelStore) CountUrgentPostsAfter(channelId string, timestamp int64
|
||||
return int(urgent), nil
|
||||
}
|
||||
|
||||
// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by that user.
|
||||
func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, userId string) (int, int, error) {
|
||||
// CountPostsAfter returns the number of posts in the given channel created after but not including the given timestamp. If given a non-empty user ID, only counts posts made by any other user.
|
||||
func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, excludedUserID string) (int, int, error) {
|
||||
joinLeavePostTypes := []string{
|
||||
// These types correspond to the ones checked by Post.IsJoinLeaveMessage
|
||||
model.PostTypeJoinLeave,
|
||||
@ -2694,8 +2694,8 @@ func (s SqlChannelStore) CountPostsAfter(channelId string, timestamp int64, user
|
||||
sq.Eq{"DeleteAt": 0},
|
||||
})
|
||||
|
||||
if userId != "" {
|
||||
query = query.Where(sq.Eq{"UserId": userId})
|
||||
if excludedUserID != "" {
|
||||
query = query.Where(sq.NotEq{"UserId": excludedUserID})
|
||||
}
|
||||
sql, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
|
@ -248,8 +248,8 @@ type ChannelStore interface {
|
||||
PermanentDeleteMembersByChannel(channelID string) error
|
||||
UpdateLastViewedAt(channelIds []string, userID string) (map[string]int64, error)
|
||||
UpdateLastViewedAtPost(unreadPost *model.Post, userID string, mentionCount, mentionCountRoot, urgentMentionCount int, setUnreadCountRoot bool) (*model.ChannelUnreadAt, error)
|
||||
CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error)
|
||||
CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error)
|
||||
CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error)
|
||||
CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error)
|
||||
IncrementMentionCount(channelID string, userIDs []string, isRoot, isUrgent bool) error
|
||||
AnalyticsTypeCount(teamID string, channelType model.ChannelType) (int64, error)
|
||||
GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error)
|
||||
|
@ -4479,6 +4479,7 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
|
||||
t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
|
||||
userId1 := model.NewId()
|
||||
userId2 := model.NewId()
|
||||
userId3 := model.NewId()
|
||||
|
||||
channelId := model.NewId()
|
||||
|
||||
@ -4503,21 +4504,28 @@ func testCountPostsAfter(t *testing.T, ss store.Store) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ss.Post().Save(&model.Post{
|
||||
UserId: userId3,
|
||||
ChannelId: channelId,
|
||||
CreateAt: 1003,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, _, err := ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
assert.Equal(t, 4, count)
|
||||
|
||||
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId1)
|
||||
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt-1, userId2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
count, _, err = ss.Channel().CountPostsAfter(channelId, p1.CreateAt, userId1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("should not count deleted posts", func(t *testing.T) {
|
||||
@ -4623,6 +4631,7 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
|
||||
t.Run("should count all posts with or without the given user ID", func(t *testing.T) {
|
||||
userId1 := model.NewId()
|
||||
userId2 := model.NewId()
|
||||
userId3 := model.NewId()
|
||||
|
||||
channelId := model.NewId()
|
||||
|
||||
@ -4661,19 +4670,33 @@ func testCountUrgentPostsAfter(t *testing.T, ss store.Store) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ss.Post().Save(&model.Post{
|
||||
UserId: userId3,
|
||||
ChannelId: channelId,
|
||||
CreateAt: 1003,
|
||||
Metadata: &model.PostMetadata{
|
||||
Priority: &model.PostPriority{
|
||||
Priority: model.NewString(model.PostPriorityUrgent),
|
||||
RequestedAck: model.NewBool(false),
|
||||
PersistentNotifications: model.NewBool(false),
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId1)
|
||||
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt-1, userId3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId1)
|
||||
count, err = ss.Channel().CountUrgentPostsAfter(channelId, p1.CreateAt, userId3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
})
|
||||
|
@ -184,30 +184,30 @@ func (_m *ChannelStore) ClearSidebarOnTeamLeave(userID string, teamID string) er
|
||||
return r0
|
||||
}
|
||||
|
||||
// CountPostsAfter provides a mock function with given fields: channelID, timestamp, userID
|
||||
func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
|
||||
ret := _m.Called(channelID, timestamp, userID)
|
||||
// CountPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
|
||||
func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
|
||||
ret := _m.Called(channelID, timestamp, excludedUserID)
|
||||
|
||||
var r0 int
|
||||
var r1 int
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(string, int64, string) (int, int, error)); ok {
|
||||
return rf(channelID, timestamp, userID)
|
||||
return rf(channelID, timestamp, excludedUserID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
|
||||
r0 = rf(channelID, timestamp, userID)
|
||||
r0 = rf(channelID, timestamp, excludedUserID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, int64, string) int); ok {
|
||||
r1 = rf(channelID, timestamp, userID)
|
||||
r1 = rf(channelID, timestamp, excludedUserID)
|
||||
} else {
|
||||
r1 = ret.Get(1).(int)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(2).(func(string, int64, string) error); ok {
|
||||
r2 = rf(channelID, timestamp, userID)
|
||||
r2 = rf(channelID, timestamp, excludedUserID)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
@ -215,23 +215,23 @@ func (_m *ChannelStore) CountPostsAfter(channelID string, timestamp int64, userI
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, userID
|
||||
func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
|
||||
ret := _m.Called(channelID, timestamp, userID)
|
||||
// CountUrgentPostsAfter provides a mock function with given fields: channelID, timestamp, excludedUserID
|
||||
func (_m *ChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
|
||||
ret := _m.Called(channelID, timestamp, excludedUserID)
|
||||
|
||||
var r0 int
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string, int64, string) (int, error)); ok {
|
||||
return rf(channelID, timestamp, userID)
|
||||
return rf(channelID, timestamp, excludedUserID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, int64, string) int); ok {
|
||||
r0 = rf(channelID, timestamp, userID)
|
||||
r0 = rf(channelID, timestamp, excludedUserID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, int64, string) error); ok {
|
||||
r1 = rf(channelID, timestamp, userID)
|
||||
r1 = rf(channelID, timestamp, excludedUserID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
@ -731,10 +731,10 @@ func (s *TimerLayerChannelStore) ClearSidebarOnTeamLeave(userID string, teamID s
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, userID string) (int, int, error) {
|
||||
func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, int, error) {
|
||||
start := time.Now()
|
||||
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, userID)
|
||||
result, resultVar1, err := s.ChannelStore.CountPostsAfter(channelID, timestamp, excludedUserID)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
@ -747,10 +747,10 @@ func (s *TimerLayerChannelStore) CountPostsAfter(channelID string, timestamp int
|
||||
return result, resultVar1, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, userID string) (int, error) {
|
||||
func (s *TimerLayerChannelStore) CountUrgentPostsAfter(channelID string, timestamp int64, excludedUserID string) (int, error) {
|
||||
start := time.Now()
|
||||
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, userID)
|
||||
result, err := s.ChannelStore.CountUrgentPostsAfter(channelID, timestamp, excludedUserID)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user