MM-32525 Incorrect number of mentions for channels when threads are enabled (#16853)

This commit is contained in:
Eli Yukelzon
2021-02-09 12:03:32 +02:00
committed by GitHub
parent 78b82769ca
commit 9a33c3706a
13 changed files with 195 additions and 12 deletions

View File

@@ -94,6 +94,7 @@ func (api *API) InitUser() {
api.BaseRoutes.UserThreads.Handle("", api.ApiSessionRequired(getThreadsForUser)).Methods("GET")
api.BaseRoutes.UserThreads.Handle("/read", api.ApiSessionRequired(updateReadStateAllThreadsByUser)).Methods("PUT")
api.BaseRoutes.UserThreads.Handle("/mention_counts", api.ApiSessionRequired(getMentionCountsForAllThreadsByUser)).Methods("GET")
api.BaseRoutes.UserThread.Handle("", api.ApiSessionRequired(getThreadForUser)).Methods("GET")
api.BaseRoutes.UserThread.Handle("/following", api.ApiSessionRequired(followThreadByUser)).Methods("PUT")
@@ -2842,6 +2843,26 @@ func getThreadForUser(c *Context, w http.ResponseWriter, r *http.Request) {
w.Write([]byte(threads.ToJson()))
}
func getMentionCountsForAllThreadsByUser(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequireUserId().RequireTeamId()
if c.Err != nil {
return
}
if !c.App.SessionHasPermissionToUser(*c.App.Session(), c.Params.UserId) {
c.SetPermissionError(model.PERMISSION_EDIT_OTHER_USERS)
return
}
counts, err := c.App.GetThreadMentionsForUserPerChannel(c.Params.UserId, c.Params.TeamId)
if err != nil {
c.Err = err
return
}
resp, _ := json.Marshal(counts)
w.Write(resp)
}
func getThreadsForUser(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequireUserId().RequireTeamId()
if c.Err != nil {

View File

@@ -5754,56 +5754,66 @@ func TestMaintainUnreadMentionsInThread(t *testing.T) {
*cfg.ServiceSettings.ThreadAutoFollow = true
*cfg.ServiceSettings.CollapsedThreads = model.COLLAPSED_THREADS_DEFAULT_ON
})
checkThreadList := func(client *model.Client4, userId string, expectedThreads int) (*model.Threads, *model.Response) {
checkMentionCounts := func(client *model.Client4, userId string, expected map[string]int64) {
actual, resp2 := client.GetThreadMentionsForUserPerChannel(userId, th.BasicTeam.Id)
CheckNoError(t, resp2)
require.EqualValues(t, expected, actual)
}
checkThreadList := func(client *model.Client4, userId string, expectedMentions, expectedThreads int) (*model.Threads, *model.Response) {
uss, resp := client.GetUserThreads(userId, th.BasicTeam.Id, model.GetUserThreadsOpts{
Deleted: false,
})
CheckNoError(t, resp)
require.Len(t, uss.Threads, expectedThreads)
sum := int64(0)
for _, thr := range uss.Threads {
sum += thr.UnreadMentions
}
require.Equal(t, sum, uss.TotalUnreadMentions)
require.EqualValues(t, expectedMentions, uss.TotalUnreadMentions)
return uss, resp
}
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.BasicUser.Id)
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
// create regular post
rpost, _ := postAndCheck(t, Client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testMsg"})
// create reply and mention the original poster and another user
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply @" + th.BasicUser.Username + " and @" + th.BasicUser2.Username, RootId: rpost.Id})
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.BasicUser.Id)
defer th.App.Srv().Store.Post().PermanentDeleteByUser(th.SystemAdminUser.Id)
checkMentionCounts(Client, th.BasicUser.Id, map[string]int64{th.BasicChannel.Id: 1})
// basic user 1 was mentioned 1 time
checkThreadList(th.Client, th.BasicUser.Id, 1)
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
// basic user 2 was mentioned 1 time
checkThreadList(th.SystemAdminClient, th.BasicUser2.Id, 1)
checkThreadList(th.SystemAdminClient, th.BasicUser2.Id, 1, 1)
// test self mention, shouldn't increase mention count
postAndCheck(t, Client, &model.Post{ChannelId: th.BasicChannel.Id, Message: "testReply @" + th.BasicUser.Username, RootId: rpost.Id})
// count should increase
checkThreadList(th.Client, th.BasicUser.Id, 1)
// count shouldn't increase
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
// test DM
dm := th.CreateDmChannel(th.SystemAdminUser)
dm_root_post, _ := postAndCheck(t, Client, &model.Post{ChannelId: dm.Id, Message: "hi @" + th.SystemAdminUser.Username})
// no changes
checkThreadList(th.Client, th.BasicUser.Id, 1)
checkThreadList(th.Client, th.BasicUser.Id, 1, 1)
// post reply by the same user
postAndCheck(t, Client, &model.Post{ChannelId: dm.Id, Message: "how are you", RootId: dm_root_post.Id})
// thread created
checkThreadList(th.Client, th.BasicUser.Id, 2)
checkThreadList(th.Client, th.BasicUser.Id, 1, 2)
// post two replies by another user, without mentions. mention count should still increase since this is a DM
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: dm.Id, Message: "msg1", RootId: dm_root_post.Id})
postAndCheck(t, th.SystemAdminClient, &model.Post{ChannelId: dm.Id, Message: "msg2", RootId: dm_root_post.Id})
// expect increment by two mentions
checkThreadList(th.Client, th.BasicUser.Id, 2)
checkThreadList(th.Client, th.BasicUser.Id, 3, 2)
checkMentionCounts(Client, th.BasicUser.Id, map[string]int64{th.BasicChannel.Id: 1, dm.Id: 2})
}
func TestReadThreads(t *testing.T) {

View File

@@ -699,6 +699,7 @@ type AppIface interface {
GetTermsOfService(id string) (*model.TermsOfService, *model.AppError)
GetThreadForUser(userID, teamID, threadId string, extended bool) (*model.ThreadResponse, *model.AppError)
GetThreadMembershipsForUser(userID, teamID string) ([]*model.ThreadMembership, error)
GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *model.AppError)
GetThreadsForUser(userID, teamID string, options model.GetUserThreadsOpts) (*model.Threads, *model.AppError)
GetUploadSession(uploadId string) (*model.UploadSession, *model.AppError)
GetUploadSessionsForUser(userID string) ([]*model.UploadSession, *model.AppError)

View File

@@ -8612,6 +8612,28 @@ func (a *OpenTracingAppLayer) GetThreadMembershipsForUser(userID string, teamID
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetThreadMentionsForUserPerChannel")
a.ctx = newCtx
a.app.Srv().Store.SetContext(newCtx)
defer func() {
a.app.Srv().Store.SetContext(origCtx)
a.ctx = origCtx
}()
defer span.Finish()
resultVar0, resultVar1 := a.app.GetThreadMentionsForUserPerChannel(userId, teamId)
if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1))
ext.Error.Set(span, true)
}
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) GetThreadsForUser(userID string, teamID string, options model.GetUserThreadsOpts) (*model.Threads, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetThreadsForUser")

View File

@@ -2368,6 +2368,14 @@ func (a *App) GetThreadsForUser(userID, teamID string, options model.GetUserThre
return threads, nil
}
func (a *App) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *model.AppError) {
res, err := a.Srv().Store.Thread().GetThreadMentionsForUserPerChannel(userId, teamId)
if err != nil {
return nil, model.NewAppError("GetThreadMentionsForUserPerChannel", "app.user.get_threads_for_user.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return res, nil
}
func (a *App) GetThreadForUser(userID, teamID, threadId string, extended bool) (*model.ThreadResponse, *model.AppError) {
thread, err := a.Srv().Store.Thread().GetThreadForUser(userID, teamID, threadId, extended)
if err != nil {

View File

@@ -5816,6 +5816,20 @@ func (c *Client4) ListImports() ([]string, *Response) {
return ArrayFromJson(r.Body), BuildResponse(r)
}
func (c *Client4) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, *Response) {
url := c.GetUserThreadsRoute(userId, teamId)
r, appErr := c.DoApiGet(url+"/mention_counts", "")
if appErr != nil {
return nil, BuildErrorResponse(r, appErr)
}
defer closeBody(r)
var counts map[string]int64
json.NewDecoder(r.Body).Decode(&counts)
return counts, BuildResponse(r)
}
func (c *Client4) GetUserThreads(userId, teamId string, options GetUserThreadsOpts) (*Threads, *Response) {
v := url.Values{}
if options.Since != 0 {

View File

@@ -7846,6 +7846,24 @@ func (s *OpenTracingLayerThreadStore) GetThreadForUser(userId string, teamId str
return result, err
}
func (s *OpenTracingLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.GetThreadMentionsForUserPerChannel")
s.Root.Store.SetContext(newCtx)
defer func() {
s.Root.Store.SetContext(origCtx)
}()
defer span.Finish()
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
return result, err
}
func (s *OpenTracingLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.GetThreadsForUser")

View File

@@ -8518,6 +8518,26 @@ func (s *RetryLayerThreadStore) GetThreadForUser(userId string, teamId string, t
}
func (s *RetryLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
tries := 0
for {
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
if err == nil {
return result, nil
}
if !isRepeatableError(err) {
return result, err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
return result, err
}
}
}
func (s *RetryLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
tries := 0

View File

@@ -108,6 +108,35 @@ func (s *SqlThreadStore) Get(id string) (*model.Thread, error) {
return &thread, nil
}
func (s *SqlThreadStore) GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, error) {
type Count struct {
UnreadMentions int64
ChannelId string
}
var counts []Count
sql, args, _ := s.getQueryBuilder().
Select("SUM(UnreadMentions) as UnreadMentions", "ChannelId").
From("ThreadMemberships").
LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
Where(sq.And{
sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
sq.Eq{"ThreadMemberships.UserId": userId},
sq.Eq{"ThreadMemberships.Following": true},
}).
GroupBy("Threads.ChannelId").ToSql()
if _, err := s.GetMaster().Select(&counts, sql, args...); err != nil {
return nil, err
}
result := map[string]int64{}
for _, count := range counts {
result[count.ChannelId] = count.UnreadMentions
}
return result, nil
}
func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
type JoinedThread struct {
PostId string

View File

@@ -255,6 +255,7 @@ type ThreadStore interface {
GetThreadForUser(userId, teamId, threadId string, extended bool) (*model.ThreadResponse, error)
Delete(postId string) error
GetPosts(threadId string, since int64) ([]*model.Post, error)
GetThreadMentionsForUserPerChannel(userId, teamId string) (map[string]int64, error)
MarkAllAsRead(userId, teamId string) error
MarkAsRead(userId, threadId string, timestamp int64) error

View File

@@ -194,6 +194,29 @@ func (_m *ThreadStore) GetThreadForUser(userId string, teamId string, threadId s
return r0, r1
}
// GetThreadMentionsForUserPerChannel provides a mock function with given fields: userId, teamId
func (_m *ThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
ret := _m.Called(userId, teamId)
var r0 map[string]int64
if rf, ok := ret.Get(0).(func(string, string) map[string]int64); ok {
r0 = rf(userId, teamId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]int64)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(userId, teamId)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetThreadsForUser provides a mock function with given fields: userId, teamId, opts
func (_m *ThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
ret := _m.Called(userId, teamId, opts)

View File

@@ -40,7 +40,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
ChannelId: c.Id,
UserId: u1.Id,
NotifyProps: model.GetDefaultChannelNotifyProps(),
MsgCount: 90,
MsgCount: 0,
})
require.NoError(t, err44)
o := model.Post{}

View File

@@ -7080,6 +7080,22 @@ func (s *TimerLayerThreadStore) GetThreadForUser(userId string, teamId string, t
return result, err
}
func (s *TimerLayerThreadStore) GetThreadMentionsForUserPerChannel(userId string, teamId string) (map[string]int64, error) {
start := timemodule.Now()
result, err := s.ThreadStore.GetThreadMentionsForUserPerChannel(userId, teamId)
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
if s.Root.Metrics != nil {
success := "false"
if err == nil {
success = "true"
}
s.Root.Metrics.ObserveStoreMethodDuration("ThreadStore.GetThreadMentionsForUserPerChannel", success, elapsed)
}
return result, err
}
func (s *TimerLayerThreadStore) GetThreadsForUser(userId string, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
start := timemodule.Now()