mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-32525 Incorrect number of mentions for channels when threads are enabled (#16853)
This commit is contained in:
21
api4/user.go
21
api4/user.go
@@ -94,6 +94,7 @@ func (api *API) InitUser() {
|
||||
|
||||
api.BaseRoutes.UserThreads.Handle("", api.ApiSessionRequired(getThreadsForUser)).Methods("GET")
|
||||
api.BaseRoutes.UserThreads.Handle("/read", api.ApiSessionRequired(updateReadStateAllThreadsByUser)).Methods("PUT")
|
||||
api.BaseRoutes.UserThreads.Handle("/mention_counts", api.ApiSessionRequired(getMentionCountsForAllThreadsByUser)).Methods("GET")
|
||||
|
||||
api.BaseRoutes.UserThread.Handle("", api.ApiSessionRequired(getThreadForUser)).Methods("GET")
|
||||
api.BaseRoutes.UserThread.Handle("/following", api.ApiSessionRequired(followThreadByUser)).Methods("PUT")
|
||||
@@ -2842,6 +2843,26 @@ func getThreadForUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(threads.ToJson()))
|
||||
}
|
||||
|
||||
func getMentionCountsForAllThreadsByUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.RequireUserId().RequireTeamId()
|
||||
if c.Err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.App.SessionHasPermissionToUser(*c.App.Session(), c.Params.UserId) {
|
||||
c.SetPermissionError(model.PERMISSION_EDIT_OTHER_USERS)
|
||||
return
|
||||
}
|
||||
counts, err := c.App.GetThreadMentionsForUserPerChannel(c.Params.UserId, c.Params.TeamId)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
return
|
||||
}
|
||||
resp, _ := json.Marshal(counts)
|
||||
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func getThreadsForUser(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.RequireUserId().RequireTeamId()
|
||||
if c.Err != nil {
|
||||
|
||||
@@ -5754,56 +5754,66 @@ func TestMaintainUnreadMentionsInThread(t *testing.T) {
|
||||
*cfg.ServiceSettings.ThreadAutoFollow = true
|
||||
*cfg.ServiceSettings.CollapsedThreads = model.COLLAPSED_THREADS_DEFAULT_ON
|
||||
})
|
||||
|
||||
checkThreadList := func(client *model.Client4, userId string, expectedThreads int) (*model.Threads, *model.Response) {
|
||||
checkMentionCounts := func(client *model.Client4, userId string, expected map[string]int64) {
|
||||
actual, resp2 := client.GetThreadMentionsForUserPerChannel(userId, th.BasicTeam.Id)
|
||||
CheckNoError(t, resp2)
|
||||
require.EqualValues(t, expected, actual)
|
||||
}
|
||||
checkThreadList := func(client *model.Client4, userId string, expectedMentions, expectedThreads int) (*model.Threads, *model.Response) {
|
||||
uss, resp := client.GetUserThreads(userId, th.BasicTeam.Id, model.GetUserThreadsOpts{
|
||||
Deleted: false,
|
||||
})
|
||||
CheckNoError(t, resp)
|
||||
|
||||
require.Len(t, uss.Threads, expectedThreads)
|
||||
sum := int64(0)
|
||||
for _, thr := range uss.Threads {
|
||||
sum += thr.UnreadMentions
|
||||
}
|
||||
require.Equal(t, sum, uss.TotalUnreadMentions)
|
||||
require.EqualValues(t, expectedMentions, uss.TotalUnreadMentions)
|
||||
|
||||
return uss, resp
|
||||
}
|
||||
|
||||
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.BasicUser.Id)
|
||||
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
|
||||
|
||||
// create regular post
|
||||
rpost, _ := postAndCheck(t, Client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testMsg"})
|
||||
// create reply and mention the original poster and another user
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply @" + th.BasicUser.Username + " and @" + th.BasicUser2.Username, RootId: rpost.Id})
|
||||
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.BasicUser.Id)
|
||||
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
|
||||
|
||||
checkMentionCounts(Client, th.BasicUser.Id, map[string]int64{th.BasicChannel.Id: 1})
|
||||
// basic user 1 was mentioned 1 time
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1)
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
|
||||
// basic user 2 was mentioned 1 time
|
||||
checkThreadList(th.SystemAdminClient, th.BasicUser2.Id, 1)
|
||||
checkThreadList(th.SystemAdminClient, th.BasicUser2.Id, 1, 1)
|
||||
|
||||
// test self mention, shouldn't increase mention count
|
||||
postAndCheck(t, Client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply @" + th.BasicUser.Username, RootId: rpost.Id})
|
||||
// count should increase
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1)
|
||||
// count shouldn't increase
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
|
||||
|
||||
// test DM
|
||||
dm := th.CreateDmChannel(th.SystemAdminUser)
|
||||
dm_root_post, _ := postAndCheck(t, Client, &model.Post{ChannelId: dm.Id, Message: "hi @" + th.SystemAdminUser.Username})
|
||||
|
||||
// no changes
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1)
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
|
||||
|
||||
// post reply by the same user
|
||||
postAndCheck(t, Client, &model.Post{ChannelId: dm.Id, Message: "how are you", RootId: dm_root_post.Id})
|
||||
|
||||
// thread created
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 2)
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 1, 2)
|
||||
|
||||
// post two replies by another user, without mentions. mention count should still increase since this is a DM
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: dm.Id, Message: "msg1", RootId: dm_root_post.Id})
|
||||
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: dm.Id, Message: "msg2", RootId: dm_root_post.Id})
|
||||
// expect increment by two mentions
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 2)
|
||||
checkThreadList(th.Client, th.BasicUser.Id, 3, 2)
|
||||
checkMentionCounts(Client, th.BasicUser.Id, map[string]int64{th.BasicChannel.Id: 1, dm.Id: 2})
|
||||
}
|
||||
|
||||
func TestReadThreads(t *testing.T) {
|
||||
|
||||
@@ -699,6 +699,7 @@ type AppIface interface {
|
||||
GetTermsOfService(id string) (*model.TermsOfService, *model.AppError)
|
||||
GetThreadForUser(userID, teamID, threadId string, extended bool) (*model.ThreadResponse, *model.AppError)
|
||||
GetThreadMembershipsForUser(userID, teamID string) ([]*model.ThreadMembership, error)
|
||||
GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *model.AppError)
|
||||
GetThreadsForUser(userID, teamID string, options model.GetUserThreadsOpts) (*model.Threads, *model.AppError)
|
||||
GetUploadSession(uploadId string) (*model.UploadSession, *model.AppError)
|
||||
GetUploadSessionsForUser(userID string) ([]*model.UploadSession, *model.AppError)
|
||||
|
||||
@@ -8612,6 +8612,28 @@ func (a *OpenTracingAppLayer) GetThreadMembershipsForUser(userID string, teamID
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, *model.AppError) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetThreadMentionsForUserPerChannel")
|
||||
|
||||
a.ctx = newCtx
|
||||
a.app.Srv().Store.SetContext(newCtx)
|
||||
defer func() {
|
||||
a.app.Srv().Store.SetContext(origCtx)
|
||||
a.ctx = origCtx
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
resultVar0, resultVar1 := a.app.GetThreadMentionsForUserPerChannel(userId, teamId)
|
||||
|
||||
if resultVar1 != nil {
|
||||
span.LogFields(spanlog.Error(resultVar1))
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) GetThreadsForUser(userID string, teamID string, options model.GetUserThreadsOpts) (*model.Threads, *model.AppError) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetThreadsForUser")
|
||||
|
||||
@@ -2368,6 +2368,14 @@ func (a *App) GetThreadsForUser(userID, teamID string, options model.GetUserThre
|
||||
return threads, nil
|
||||
}
|
||||
|
||||
func (a *App) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *model.AppError) {
|
||||
res, err := a.Srv().Store.Thread().GetThreadMentionsForUserPerChannel(userId, teamId)
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("GetThreadMentionsForUserPerChannel", "app.user.get_threads_for_user.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *App) GetThreadForUser(userID, teamID, threadId string, extended bool) (*model.ThreadResponse, *model.AppError) {
|
||||
thread, err := a.Srv().Store.Thread().GetThreadForUser(userID, teamID, threadId, extended)
|
||||
if err != nil {
|
||||
|
||||
@@ -5816,6 +5816,20 @@ func (c *Client4) ListImports() ([]string, *Response) {
|
||||
return ArrayFromJson(r.Body), BuildResponse(r)
|
||||
}
|
||||
|
||||
func (c *Client4) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *Response) {
|
||||
url := c.GetUserThreadsRoute(userId, teamId)
|
||||
r, appErr := c.DoApiGet(url+"/mention_counts", "")
|
||||
if appErr != nil {
|
||||
return nil, BuildErrorResponse(r, appErr)
|
||||
}
|
||||
defer closeBody(r)
|
||||
|
||||
var counts map[string]int64
|
||||
json.NewDecoder(r.Body).Decode(&counts)
|
||||
|
||||
return counts, BuildResponse(r)
|
||||
}
|
||||
|
||||
func (c *Client4) GetUserThreads(userId, teamId string, options GetUserThreadsOpts) (*Threads, *Response) {
|
||||
v := url.Values{}
|
||||
if options.Since != 0 {
|
||||
|
||||
@@ -7846,6 +7846,24 @@ func (s *OpenTracingLayerThreadStore) GetThreadForUser(userId string, teamId str
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.GetThreadMentionsForUserPerChannel")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
defer func() {
|
||||
s.Root.Store.SetContext(origCtx)
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.GetThreadsForUser")
|
||||
|
||||
@@ -8518,6 +8518,26 @@ func (s *RetryLayerThreadStore) GetThreadForUser(userId string, teamId string, t
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
if !isRepeatableError(err) {
|
||||
return result, err
|
||||
}
|
||||
tries++
|
||||
if tries >= 3 {
|
||||
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
|
||||
return result, err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
|
||||
|
||||
tries := 0
|
||||
|
||||
@@ -108,6 +108,35 @@ func (s *SqlThreadStore) Get(id string) (*model.Thread, error) {
|
||||
return &thread, nil
|
||||
}
|
||||
|
||||
func (s *SqlThreadStore) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, error) {
|
||||
type Count struct {
|
||||
UnreadMentions int64
|
||||
ChannelId string
|
||||
}
|
||||
var counts []Count
|
||||
|
||||
sql, args, _ := s.getQueryBuilder().
|
||||
Select("SUM(UnreadMentions) as UnreadMentions", "ChannelId").
|
||||
From("ThreadMemberships").
|
||||
LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
|
||||
LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
|
||||
Where(sq.And{
|
||||
sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
|
||||
sq.Eq{"ThreadMemberships.UserId": userId},
|
||||
sq.Eq{"ThreadMemberships.Following": true},
|
||||
}).
|
||||
GroupBy("Threads.ChannelId").ToSql()
|
||||
|
||||
if _, err := s.GetMaster().Select(&counts, sql, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := map[string]int64{}
|
||||
for _, count := range counts {
|
||||
result[count.ChannelId] = count.UnreadMentions
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
|
||||
type JoinedThread struct {
|
||||
PostId string
|
||||
|
||||
@@ -255,6 +255,7 @@ type ThreadStore interface {
|
||||
GetThreadForUser(userId, teamId, threadId string, extended bool) (*model.ThreadResponse, error)
|
||||
Delete(postId string) error
|
||||
GetPosts(threadId string, since int64) ([]*model.Post, error)
|
||||
GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, error)
|
||||
|
||||
MarkAllAsRead(userId, teamId string) error
|
||||
MarkAsRead(userId, threadId string, timestamp int64) error
|
||||
|
||||
@@ -194,6 +194,29 @@ func (_m *ThreadStore) GetThreadForUser(userId string, teamId string, threadId s
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetThreadMentionsForUserPerChannel provides a mock function with given fields: userId, teamId
|
||||
func (_m *ThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
|
||||
ret := _m.Called(userId, teamId)
|
||||
|
||||
var r0 map[string]int64
|
||||
if rf, ok := ret.Get(0).(func(string, string) map[string]int64); ok {
|
||||
r0 = rf(userId, teamId)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string]int64)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string, string) error); ok {
|
||||
r1 = rf(userId, teamId)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetThreadsForUser provides a mock function with given fields: userId, teamId, opts
|
||||
func (_m *ThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
|
||||
ret := _m.Called(userId, teamId, opts)
|
||||
|
||||
@@ -40,7 +40,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
|
||||
ChannelId: c.Id,
|
||||
UserId: u1.Id,
|
||||
NotifyProps: model.GetDefaultChannelNotifyProps(),
|
||||
MsgCount: 90,
|
||||
MsgCount: 0,
|
||||
})
|
||||
require.NoError(t, err44)
|
||||
o := model.Post{}
|
||||
|
||||
@@ -7080,6 +7080,22 @@ func (s *TimerLayerThreadStore) GetThreadForUser(userId string, teamId string, t
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
|
||||
start := timemodule.Now()
|
||||
|
||||
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
|
||||
|
||||
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
success := "false"
|
||||
if err == nil {
|
||||
success = "true"
|
||||
}
|
||||
s.Root.Metrics.ObserveStoreMethodDuration("ThreadStore.GetThreadMentionsForUserPerChannel", success, elapsed)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
|
||||
start := timemodule.Now()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user