diff --git a/api4/team.go b/api4/team.go index 44ea08501f..559bda9fef 100644 --- a/api4/team.go +++ b/api4/team.go @@ -420,8 +420,9 @@ func getTeamsUnreadForUser(c *Context, w http.ResponseWriter, r *http.Request) { // optional team id to be excluded from the result teamId := r.URL.Query().Get("exclude_team") + includeCollapsedThreads := r.URL.Query().Get("include_collapsed_threads") == "true" - unreadTeamsList, err := c.App.GetTeamsUnreadForUser(teamId, c.Params.UserId) + unreadTeamsList, err := c.App.GetTeamsUnreadForUser(teamId, c.Params.UserId, includeCollapsedThreads) if err != nil { c.Err = err return diff --git a/api4/team_test.go b/api4/team_test.go index 60e097c9df..f8978d66e3 100644 --- a/api4/team_test.go +++ b/api4/team_test.go @@ -2651,22 +2651,22 @@ func TestGetMyTeamsUnread(t *testing.T) { user := th.BasicUser Client.Login(user.Email, user.Password) - teams, resp := Client.GetTeamsUnreadForUser(user.Id, "") + teams, resp := Client.GetTeamsUnreadForUser(user.Id, "", true) CheckNoError(t, resp) require.NotEqual(t, len(teams), 0, "should have results") - teams, resp = Client.GetTeamsUnreadForUser(user.Id, th.BasicTeam.Id) + teams, resp = Client.GetTeamsUnreadForUser(user.Id, th.BasicTeam.Id, true) CheckNoError(t, resp) require.Empty(t, teams, "should not have results") - _, resp = Client.GetTeamsUnreadForUser("fail", "") + _, resp = Client.GetTeamsUnreadForUser("fail", "", true) CheckBadRequestStatus(t, resp) - _, resp = Client.GetTeamsUnreadForUser(model.NewId(), "") + _, resp = Client.GetTeamsUnreadForUser(model.NewId(), "", true) CheckForbiddenStatus(t, resp) Client.Logout() - _, resp = Client.GetTeamsUnreadForUser(user.Id, "") + _, resp = Client.GetTeamsUnreadForUser(user.Id, "", true) CheckUnauthorizedStatus(t, resp) } diff --git a/api4/user_test.go b/api4/user_test.go index 49c53edc01..0087addb87 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -5551,7 +5551,7 @@ func TestGetThreadsForUser(t *testing.T) { require.Nil(t, resp.Error) require.Len(t, uss.Threads, 10) - require.Equal(t, uss.Threads[0].PostId, rootIdBefore) + require.Equal(t, rootIdBefore, uss.Threads[0].PostId) uss2, resp2 := th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ Deleted: false, @@ -5561,7 +5561,7 @@ func TestGetThreadsForUser(t *testing.T) { require.Nil(t, resp2.Error) require.Len(t, uss2.Threads, 10) - require.Equal(t, uss2.Threads[0].PostId, rootIdAfter) + require.Equal(t, rootIdAfter, uss2.Threads[0].PostId) uss3, resp3 := th.Client.GetUserThreads(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{ Deleted: false, diff --git a/app/app_iface.go b/app/app_iface.go index 5e87c74c01..a2ff74d465 100644 --- a/app/app_iface.go +++ b/app/app_iface.go @@ -762,7 +762,7 @@ type AppIface interface { GetTeamsForScheme(scheme *model.Scheme, offset int, limit int) ([]*model.Team, *model.AppError) GetTeamsForSchemePage(scheme *model.Scheme, page int, perPage int) ([]*model.Team, *model.AppError) GetTeamsForUser(userID string) ([]*model.Team, *model.AppError) - GetTeamsUnreadForUser(excludeTeamId string, userID string) ([]*model.TeamUnread, *model.AppError) + GetTeamsUnreadForUser(excludeTeamId string, userID string, includeCollapsedThreads bool) ([]*model.TeamUnread, *model.AppError) GetTermsOfService(id string) (*model.TermsOfService, *model.AppError) GetThreadForUser(teamID string, threadMembership *model.ThreadMembership, extended bool) (*model.ThreadResponse, *model.AppError) GetThreadMembershipForUser(userId, threadId string) (*model.ThreadMembership, *model.AppError) diff --git a/app/opentracing/opentracing_layer.go b/app/opentracing/opentracing_layer.go index 510c681c54..0ddc2ba788 100644 --- a/app/opentracing/opentracing_layer.go +++ b/app/opentracing/opentracing_layer.go @@ -9250,7 +9250,7 @@ func (a *OpenTracingAppLayer) GetTeamsForUser(userID string) ([]*model.Team, *mo return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) GetTeamsUnreadForUser(excludeTeamId string, userID string) ([]*model.TeamUnread, *model.AppError) { +func (a *OpenTracingAppLayer) GetTeamsUnreadForUser(excludeTeamId string, userID string, includeCollapsedThreads bool) ([]*model.TeamUnread, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetTeamsUnreadForUser") @@ -9262,7 +9262,7 @@ func (a *OpenTracingAppLayer) GetTeamsUnreadForUser(excludeTeamId string, userID }() defer span.Finish() - resultVar0, resultVar1 := a.app.GetTeamsUnreadForUser(excludeTeamId, userID) + resultVar0, resultVar1 := a.app.GetTeamsUnreadForUser(excludeTeamId, userID, includeCollapsedThreads) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) diff --git a/app/plugin_api.go b/app/plugin_api.go index 949e8c9946..ec5e9ac963 100644 --- a/app/plugin_api.go +++ b/app/plugin_api.go @@ -183,7 +183,7 @@ func (api *PluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) } func (api *PluginAPI) GetTeamsUnreadForUser(userID string) ([]*model.TeamUnread, *model.AppError) { - return api.app.GetTeamsUnreadForUser("", userID) + return api.app.GetTeamsUnreadForUser("", userID, false) } func (api *PluginAPI) UpdateTeam(team *model.Team) (*model.Team, *model.AppError) { diff --git a/app/team.go b/app/team.go index 4119dc689d..fefa1e0051 100644 --- a/app/team.go +++ b/app/team.go @@ -1654,7 +1654,7 @@ func (a *App) FindTeamByName(name string) bool { return true } -func (a *App) GetTeamsUnreadForUser(excludeTeamId string, userID string) ([]*model.TeamUnread, *model.AppError) { +func (a *App) GetTeamsUnreadForUser(excludeTeamId string, userID string, includeCollapsedThreads bool) ([]*model.TeamUnread, *model.AppError) { data, err := a.Srv().Store.Team().GetChannelUnreadsForAllTeams(excludeTeamId, userID) if err != nil { return nil, model.NewAppError("GetTeamsUnreadForUser", "app.team.get_unread.app_error", nil, err.Error(), http.StatusInternalServerError) @@ -1681,17 +1681,29 @@ func (a *App) GetTeamsUnreadForUser(excludeTeamId string, userID string) ([]*mod membersMap[id] = unreads(data[i], mu) } else { membersMap[id] = unreads(data[i], &model.TeamUnread{ - MsgCount: 0, - MentionCount: 0, - MentionCountRoot: 0, - MsgCountRoot: 0, - TeamId: id, + MsgCount: 0, + MentionCount: 0, + MentionCountRoot: 0, + MsgCountRoot: 0, + ThreadCount: 0, + ThreadMentionCount: 0, + TeamId: id, }) } } - for _, val := range membersMap { - members = append(members, val) + includeCollapsedThreads = includeCollapsedThreads && *a.Config().ServiceSettings.CollapsedThreads != model.COLLAPSED_THREADS_DISABLED + + for _, member := range membersMap { + if includeCollapsedThreads { + data, err := a.Srv().Store.Thread().GetThreadsForUser(userID, member.TeamId, model.GetUserThreadsOpts{TotalsOnly: true, TeamOnly: true}) + if err != nil { + return nil, model.NewAppError("GetTeamsUnreadForUser", "app.team.get_unread.app_error", nil, err.Error(), http.StatusInternalServerError) + } + member.ThreadCount = data.TotalUnreadThreads + member.ThreadMentionCount = data.TotalUnreadMentions + } + members = append(members, member) } return members, nil diff --git a/model/client4.go b/model/client4.go index c25bdea523..ce73cbe22a 100644 --- a/model/client4.go +++ b/model/client4.go @@ -1445,14 +1445,20 @@ func (c *Client4) AttachDeviceId(deviceId string) (bool, *Response) { // GetTeamsUnreadForUser will return an array with TeamUnread objects that contain the amount // of unread messages and mentions the current user has for the teams it belongs to. -// An optional team ID can be set to exclude that team from the results. Must be authenticated. -func (c *Client4) GetTeamsUnreadForUser(userId, teamIdToExclude string) ([]*TeamUnread, *Response) { - var optional string +// An optional team ID can be set to exclude that team from the results. +// An optional boolean can be set to include collapsed thread unreads. Must be authenticated. +func (c *Client4) GetTeamsUnreadForUser(userId, teamIdToExclude string, includeCollapsedThreads bool) ([]*TeamUnread, *Response) { + query := url.Values{} + if teamIdToExclude != "" { - optional += fmt.Sprintf("?exclude_team=%s", url.QueryEscape(teamIdToExclude)) + query.Set("exclude_team", teamIdToExclude) } - r, err := c.DoApiGet(c.GetUserRoute(userId)+"/teams/unread"+optional, "") + if includeCollapsedThreads { + query.Set("include_collapsed_threads", "true") + } + + r, err := c.DoApiGet(c.GetUserRoute(userId)+"/teams/unread?"+query.Encode(), "") if err != nil { return nil, BuildErrorResponse(r, err) } diff --git a/model/team_member.go b/model/team_member.go index e8a82c3e2b..f5f1cc6166 100644 --- a/model/team_member.go +++ b/model/team_member.go @@ -31,11 +31,13 @@ type TeamMember struct { //msgp:ignore TeamUnread type TeamUnread struct { - TeamId string `json:"team_id"` - MsgCount int64 `json:"msg_count"` - MentionCount int64 `json:"mention_count"` - MentionCountRoot int64 `json:"mention_count_root"` - MsgCountRoot int64 `json:"msg_count_root"` + TeamId string `json:"team_id"` + MsgCount int64 `json:"msg_count"` + MentionCount int64 `json:"mention_count"` + MentionCountRoot int64 `json:"mention_count_root"` + MsgCountRoot int64 `json:"msg_count_root"` + ThreadCount int64 `json:"thread_count"` + ThreadMentionCount int64 `json:"thread_mention_count"` } //msgp:ignore TeamMemberForExport diff --git a/model/thread.go b/model/thread.go index 9a6b5fab64..fe4a40146a 100644 --- a/model/thread.go +++ b/model/thread.go @@ -54,6 +54,12 @@ type GetUserThreadsOpts struct { // Unread will make sure that only threads with unread replies are returned Unread bool + + // TotalsOnly will not fetch any threads and just fetch the total counts + TotalsOnly bool + + // TeamOnly will only fetch threads and unreads for the specified team and excludes DMs/GMs + TeamOnly bool } func (o *ThreadResponse) ToJson() string { diff --git a/store/sqlstore/thread_store.go b/store/sqlstore/thread_store.go index ba404a30c5..d27c4e0750 100644 --- a/store/sqlstore/thread_store.go +++ b/store/sqlstore/thread_store.go @@ -131,10 +131,20 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get } fetchConditions := sq.And{ - sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, sq.Eq{"ThreadMemberships.UserId": userId}, sq.Eq{"ThreadMemberships.Following": true}, } + if opts.TeamOnly { + fetchConditions = sq.And{ + sq.Eq{"Channels.TeamId": teamId}, + fetchConditions, + } + } else { + fetchConditions = sq.And{ + sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, + fetchConditions, + } + } if !opts.Deleted { fetchConditions = sq.And{ fetchConditions, @@ -150,7 +160,11 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get totalUnreadThreadsChan := make(chan store.StoreResult, 1) totalCountChan := make(chan store.StoreResult, 1) totalUnreadMentionsChan := make(chan store.StoreResult, 1) - threadsChan := make(chan store.StoreResult, 1) + var threadsChan chan store.StoreResult + if !opts.TotalsOnly { + threadsChan = make(chan store.StoreResult, 1) + } + go func() { repliesQuery, repliesQueryArgs, _ := s.getQueryBuilder(). Select("COUNT(DISTINCT(Posts.RootId))"). @@ -195,71 +209,68 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get totalUnreadMentionsChan <- store.StoreResult{Data: totalUnreadMentions, NErr: err} close(totalUnreadMentionsChan) }() - go func() { - newFetchConditions := fetchConditions - if opts.Since > 0 { - newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}} - } - order := "DESC" - if opts.Before != "" { - newFetchConditions = sq.And{ - newFetchConditions, - sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before), - } - } - if opts.After != "" { - order = "ASC" - newFetchConditions = sq.And{ - newFetchConditions, - sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After), - } - } - if opts.Unread { - newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} - } - unreadRepliesFetchConditions := sq.And{ - sq.Expr("Posts.RootId = ThreadMemberships.PostId"), - sq.Expr("Posts.CreateAt > ThreadMemberships.LastViewed"), - } - if !opts.Deleted { - unreadRepliesFetchConditions = sq.And{ - unreadRepliesFetchConditions, - sq.Expr("Posts.DeleteAt = 0"), + if !opts.TotalsOnly { + go func() { + newFetchConditions := fetchConditions + if opts.Since > 0 { + newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}} + } + order := "DESC" + if opts.Before != "" { + newFetchConditions = sq.And{ + newFetchConditions, + sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before), + } + } + if opts.After != "" { + order = "ASC" + newFetchConditions = sq.And{ + newFetchConditions, + sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After), + } + } + if opts.Unread { + newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} } - } - unreadRepliesQuery, _ := sq. - Select("COUNT(Posts.Id)"). - From("Posts"). - Where(unreadRepliesFetchConditions). - MustSql() + unreadRepliesFetchConditions := sq.And{ + sq.Expr("Posts.RootId = ThreadMemberships.PostId"), + sq.Expr("Posts.CreateAt > ThreadMemberships.LastViewed"), + } + if !opts.Deleted { + unreadRepliesFetchConditions = sq.And{ + unreadRepliesFetchConditions, + sq.Expr("Posts.DeleteAt = 0"), + } + } - var threads []*JoinedThread - query, args, _ := s.getQueryBuilder(). - Select(`Threads.*, + unreadRepliesQuery, _ := sq. + Select("COUNT(Posts.Id)"). + From("Posts"). + Where(unreadRepliesFetchConditions). + MustSql() + + var threads []*JoinedThread + query, args, _ := s.getQueryBuilder(). + Select(`Threads.*, ` + postSliceCoalesceQuery() + `, ThreadMemberships.LastViewed as LastViewedAt, ThreadMemberships.UnreadMentions as UnreadMentions`). - From("Threads"). - Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). - LeftJoin("Posts ON Posts.Id = Threads.PostId"). - LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). - LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId"). - Where(newFetchConditions). - OrderBy("Threads.LastReplyAt " + order). - Limit(pageSize).ToSql() + From("Threads"). + Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). + LeftJoin("Posts ON Posts.Id = Threads.PostId"). + LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). + LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId"). + Where(newFetchConditions). + OrderBy("Threads.LastReplyAt " + order). + Limit(pageSize).ToSql() - _, err := s.GetReplica().Select(&threads, query, args...) - threadsChan <- store.StoreResult{Data: threads, NErr: err} - close(threadsChan) - }() - - threadsResult := <-threadsChan - if threadsResult.NErr != nil { - return nil, threadsResult.NErr + _, err := s.GetReplica().Select(&threads, query, args...) + threadsChan <- store.StoreResult{Data: threads, NErr: err} + close(threadsChan) + }() } - threads := threadsResult.Data.([]*JoinedThread) totalUnreadMentionsResult := <-totalUnreadMentionsChan if totalUnreadMentionsResult.NErr != nil { @@ -281,26 +292,6 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get var userIds []string userIdMap := map[string]bool{} - for _, thread := range threads { - for _, participantId := range thread.Participants { - if _, ok := userIdMap[participantId]; !ok { - userIdMap[participantId] = true - userIds = append(userIds, participantId) - } - } - } - var users []*model.User - if opts.Extended { - var err error - users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true) - if err != nil { - return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId) - } - } else { - for _, userId := range userIds { - users = append(users, &model.User{Id: userId}) - } - } result := &model.Threads{ Total: totalCount, @@ -309,31 +300,59 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get TotalUnreadThreads: totalUnreadThreads, } - for _, thread := range threads { - var participants []*model.User - for _, participantId := range thread.Participants { - var participant *model.User - for _, u := range users { - if u.Id == participantId { - participant = u - break + if !opts.TotalsOnly { + threadsResult := <-threadsChan + if threadsResult.NErr != nil { + return nil, threadsResult.NErr + } + threads := threadsResult.Data.([]*JoinedThread) + for _, thread := range threads { + for _, participantId := range thread.Participants { + if _, ok := userIdMap[participantId]; !ok { + userIdMap[participantId] = true + userIds = append(userIds, participantId) } } - if participant == nil { - return nil, errors.New("cannot find thread participant with id=" + participantId) - } - participants = append(participants, participant) } - result.Threads = append(result.Threads, &model.ThreadResponse{ - PostId: thread.PostId, - ReplyCount: thread.ReplyCount, - LastReplyAt: thread.LastReplyAt, - LastViewedAt: thread.LastViewedAt, - UnreadReplies: thread.UnreadReplies, - UnreadMentions: thread.UnreadMentions, - Participants: participants, - Post: thread.Post.ToNilIfInvalid(), - }) + var users []*model.User + if opts.Extended { + var err error + users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true) + if err != nil { + return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId) + } + } else { + for _, userId := range userIds { + users = append(users, &model.User{Id: userId}) + } + } + + for _, thread := range threads { + var participants []*model.User + for _, participantId := range thread.Participants { + var participant *model.User + for _, u := range users { + if u.Id == participantId { + participant = u + break + } + } + if participant == nil { + return nil, errors.New("cannot find thread participant with id=" + participantId) + } + participants = append(participants, participant) + } + result.Threads = append(result.Threads, &model.ThreadResponse{ + PostId: thread.PostId, + ReplyCount: thread.ReplyCount, + LastReplyAt: thread.LastReplyAt, + LastViewedAt: thread.LastViewedAt, + UnreadReplies: thread.UnreadReplies, + UnreadMentions: thread.UnreadMentions, + Participants: participants, + Post: thread.Post.ToNilIfInvalid(), + }) + } } return result, nil