diff --git a/api4/channel.go b/api4/channel.go index dcf6d5b503..f7c7a72726 100644 --- a/api4/channel.go +++ b/api4/channel.go @@ -534,7 +534,23 @@ func getAllChannels(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte(channels.ToJson())) + var payload []byte + if c.Params.IncludeTotalCount { + totalCount, err := c.App.GetAllChannelsCount(opts) + if err != nil { + c.Err = err + return + } + cwc := &model.ChannelsWithCount{ + Channels: channels, + TotalCount: totalCount, + } + payload = cwc.ToJson() + } else { + payload = []byte(channels.ToJson()) + } + + w.Write(payload) } func getPublicChannelsForTeam(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/api4/channel_test.go b/api4/channel_test.go index 6c8edc0e62..fbb1ddff26 100644 --- a/api4/channel_test.go +++ b/api4/channel_test.go @@ -832,6 +832,37 @@ func TestGetAllChannels(t *testing.T) { CheckForbiddenStatus(t, resp) } +func TestGetAllChannelsWithCount(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + Client := th.Client + + channels, total, resp := th.SystemAdminClient.GetAllChannelsWithCount(0, 20, "") + CheckNoError(t, resp) + + // At least, all the not-deleted channels created during the InitBasic + require.True(t, len(*channels) >= 3) + for _, c := range *channels { + require.NotEqual(t, c.TeamId, "") + } + require.Equal(t, int64(5), total) + + channels, _, resp = th.SystemAdminClient.GetAllChannelsWithCount(0, 10, "") + CheckNoError(t, resp) + require.True(t, len(*channels) >= 3) + + channels, _, resp = th.SystemAdminClient.GetAllChannelsWithCount(1, 1, "") + CheckNoError(t, resp) + require.Len(t, *channels, 1) + + channels, _, resp = th.SystemAdminClient.GetAllChannelsWithCount(10000, 10000, "") + CheckNoError(t, resp) + require.Len(t, *channels, 0) + + _, _, resp = Client.GetAllChannelsWithCount(0, 20, "") + CheckForbiddenStatus(t, resp) +} + func TestSearchChannels(t *testing.T) { th := Setup().InitBasic() defer th.TearDown() diff --git a/app/channel.go b/app/channel.go index 3b4c4d2ef1..ddf98acd30 100644 --- a/app/channel.go +++ b/app/channel.go @@ -1243,6 +1243,18 @@ func (a *App) GetAllChannels(page, perPage int, opts model.ChannelSearchOpts) (* return a.Srv.Store.Channel().GetAllChannels(page*perPage, perPage, storeOpts) } +func (a *App) GetAllChannelsCount(opts model.ChannelSearchOpts) (int64, *model.AppError) { + if opts.ExcludeDefaultChannels { + opts.ExcludeChannelNames = a.DefaultChannelNames() + } + storeOpts := store.ChannelSearchOpts{ + ExcludeChannelNames: opts.ExcludeChannelNames, + NotAssociatedToGroup: opts.NotAssociatedToGroup, + IncludeDeleted: opts.IncludeDeleted, + } + return a.Srv.Store.Channel().GetAllChannelsCount(storeOpts) +} + func (a *App) GetDeletedChannels(teamId string, offset int, limit int) (*model.ChannelList, *model.AppError) { return a.Srv.Store.Channel().GetDeleted(teamId, offset, limit) } diff --git a/model/channel.go b/model/channel.go index 07f6f6a6f1..d0d1b8fc38 100644 --- a/model/channel.go +++ b/model/channel.go @@ -60,6 +60,11 @@ type ChannelWithTeamData struct { TeamUpdateAt int64 `json:"team_update_at"` } +type ChannelsWithCount struct { + Channels *ChannelListWithTeamData `json:"channels"` + TotalCount int64 `json:"total_count"` +} + type ChannelPatch struct { DisplayName *string `json:"display_name"` Name *string `json:"name"` @@ -111,6 +116,17 @@ func (o *ChannelPatch) ToJson() string { return string(b) } +func (o *ChannelsWithCount) ToJson() []byte { + b, _ := json.Marshal(o) + return b +} + +func ChannelsWithCountFromJson(data io.Reader) *ChannelsWithCount { + var o *ChannelsWithCount + json.NewDecoder(data).Decode(&o) + return o +} + func ChannelFromJson(data io.Reader) *Channel { var o *Channel json.NewDecoder(data).Decode(&o) diff --git a/model/client4.go b/model/client4.go index b10c35970d..7bcefc8b30 100644 --- a/model/client4.go +++ b/model/client4.go @@ -1908,6 +1908,18 @@ func (c *Client4) GetAllChannels(page int, perPage int, etag string) (*ChannelLi return ChannelListWithTeamDataFromJson(r.Body), BuildResponse(r) } +// GetAllChannelsWithCount get all the channels including the total count. Must be a system administrator. +func (c *Client4) GetAllChannelsWithCount(page int, perPage int, etag string) (*ChannelListWithTeamData, int64, *Response) { + query := fmt.Sprintf("?page=%v&per_page=%v&include_total_count=true", page, perPage) + r, err := c.DoApiGet(c.GetChannelsRoute()+query, etag) + if err != nil { + return nil, 0, BuildErrorResponse(r, err) + } + defer closeBody(r) + cwc := ChannelsWithCountFromJson(r.Body) + return cwc.Channels, cwc.TotalCount, BuildResponse(r) +} + // CreateChannel creates a channel based on the provided channel struct. func (c *Client4) CreateChannel(channel *Channel) (*Channel, *Response) { r, err := c.DoApiPost(c.GetChannelsRoute(), channel.ToJson()) diff --git a/store/sqlstore/channel_store.go b/store/sqlstore/channel_store.go index 12faa7229b..7bb774cc9c 100644 --- a/store/sqlstore/channel_store.go +++ b/store/sqlstore/channel_store.go @@ -948,27 +948,10 @@ func (s SqlChannelStore) GetChannels(teamId string, userId string, includeDelete return channels, nil } -func (s SqlChannelStore) GetAllChannels(offset int, limit int, opts store.ChannelSearchOpts) (*model.ChannelListWithTeamData, *model.AppError) { - query := s.getQueryBuilder(). - Select("c.*, Teams.DisplayName AS TeamDisplayName, Teams.Name AS TeamName, Teams.UpdateAt AS TeamUpdateAt"). - From("Channels AS c"). - Join("Teams ON Teams.Id = c.TeamId"). - Where(sq.Eq{"c.Type": []string{model.CHANNEL_PRIVATE, model.CHANNEL_OPEN}}). - OrderBy("c.DisplayName, Teams.DisplayName"). - Limit(uint64(limit)). - Offset(uint64(offset)) +func (s SqlChannelStore) GetAllChannels(offset, limit int, opts store.ChannelSearchOpts) (*model.ChannelListWithTeamData, *model.AppError) { + query := s.getAllChannelsQuery(opts, false) - if !opts.IncludeDeleted { - query = query.Where(sq.Eq{"c.DeleteAt": int(0)}) - } - - if len(opts.NotAssociatedToGroup) > 0 { - query = query.Where("c.Id NOT IN (SELECT ChannelId FROM GroupChannels WHERE GroupChannels.GroupId = ? AND GroupChannels.DeleteAt = 0)", opts.NotAssociatedToGroup) - } - - if len(opts.ExcludeChannelNames) > 0 { - query = query.Where(fmt.Sprintf("c.Name NOT IN ('%s')", strings.Join(opts.ExcludeChannelNames, "', '"))) - } + query = query.OrderBy("c.DisplayName, Teams.DisplayName").Limit(uint64(limit)).Offset(uint64(offset)) queryString, args, err := query.ToSql() if err != nil { @@ -985,6 +968,54 @@ func (s SqlChannelStore) GetAllChannels(offset int, limit int, opts store.Channe return data, nil } +func (s SqlChannelStore) GetAllChannelsCount(opts store.ChannelSearchOpts) (int64, *model.AppError) { + query := s.getAllChannelsQuery(opts, true) + + queryString, args, err := query.ToSql() + if err != nil { + return 0, model.NewAppError("SqlChannelStore.GetAllChannelsCount", "store.sql.build_query.app_error", nil, err.Error(), http.StatusInternalServerError) + } + + count, err := s.GetReplica().SelectInt(queryString, args...) + if err != nil { + return 0, model.NewAppError("SqlChannelStore.GetAllChannelsCount", "store.sql_channel.get_all_channels.get.app_error", nil, err.Error(), http.StatusInternalServerError) + } + + return count, nil +} + +func (s SqlChannelStore) getAllChannelsQuery(opts store.ChannelSearchOpts, forCount bool) sq.SelectBuilder { + var selectStr string + if forCount { + selectStr = "count(c.Id)" + } else { + selectStr = "c.*, Teams.DisplayName AS TeamDisplayName, Teams.Name AS TeamName, Teams.UpdateAt AS TeamUpdateAt" + } + + query := s.getQueryBuilder(). + Select(selectStr). + From("Channels AS c"). + Where(sq.Eq{"c.Type": []string{model.CHANNEL_PRIVATE, model.CHANNEL_OPEN}}) + + if !forCount { + query = query.Join("Teams ON Teams.Id = c.TeamId") + } + + if !opts.IncludeDeleted { + query = query.Where(sq.Eq{"c.DeleteAt": int(0)}) + } + + if len(opts.NotAssociatedToGroup) > 0 { + query = query.Where("c.Id NOT IN (SELECT ChannelId FROM GroupChannels WHERE GroupChannels.GroupId = ? AND GroupChannels.DeleteAt = 0)", opts.NotAssociatedToGroup) + } + + if len(opts.ExcludeChannelNames) > 0 { + query = query.Where(fmt.Sprintf("c.Name NOT IN ('%s')", strings.Join(opts.ExcludeChannelNames, "', '"))) + } + + return query +} + func (s SqlChannelStore) GetMoreChannels(teamId string, userId string, offset int, limit int) (*model.ChannelList, *model.AppError) { channels := &model.ChannelList{} _, err := s.GetReplica().Select(channels, ` diff --git a/store/sqlstore/group_store.go b/store/sqlstore/group_store.go index b4a3663048..83e89c6dff 100644 --- a/store/sqlstore/group_store.go +++ b/store/sqlstore/group_store.go @@ -9,7 +9,6 @@ import ( "net/http" "strings" - "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel" "github.com/mattermost/mattermost-server/model" @@ -900,7 +899,7 @@ func (s *SqlGroupStore) ChannelMembersToRemove() ([]*model.ChannelMember, *model return channelMembers, nil } -func (s *SqlGroupStore) groupsBySyncableBaseQuery(st model.GroupSyncableType, t selectType, syncableID string, opts model.GroupSearchOpts) squirrel.SelectBuilder { +func (s *SqlGroupStore) groupsBySyncableBaseQuery(st model.GroupSyncableType, t selectType, syncableID string, opts model.GroupSearchOpts) sq.SelectBuilder { selectStrs := map[selectType]string{ selectGroups: "ug.*", selectCountGroups: "COUNT(*)", @@ -1051,7 +1050,7 @@ func (s *SqlGroupStore) GetGroups(page, perPage int, opts model.GroupSearchOpts) return groups, nil } -func (s *SqlGroupStore) teamMembersMinusGroupMembersQuery(teamID string, groupIDs []string, isCount bool) squirrel.SelectBuilder { +func (s *SqlGroupStore) teamMembersMinusGroupMembersQuery(teamID string, groupIDs []string, isCount bool) sq.SelectBuilder { var selectStr string if isCount { @@ -1129,7 +1128,7 @@ func (s *SqlGroupStore) CountTeamMembersMinusGroupMembers(teamID string, groupID return count, nil } -func (s *SqlGroupStore) channelMembersMinusGroupMembersQuery(channelID string, groupIDs []string, isCount bool) squirrel.SelectBuilder { +func (s *SqlGroupStore) channelMembersMinusGroupMembersQuery(channelID string, groupIDs []string, isCount bool) sq.SelectBuilder { var selectStr string if isCount { diff --git a/store/store.go b/store/store.go index b1e719d805..29bb79d7b3 100644 --- a/store/store.go +++ b/store/store.go @@ -149,6 +149,7 @@ type ChannelStore interface { GetDeleted(team_id string, offset int, limit int) (*model.ChannelList, *model.AppError) GetChannels(teamId string, userId string, includeDeleted bool) (*model.ChannelList, *model.AppError) GetAllChannels(page, perPage int, opts ChannelSearchOpts) (*model.ChannelListWithTeamData, *model.AppError) + GetAllChannelsCount(opts ChannelSearchOpts) (int64, *model.AppError) GetMoreChannels(teamId string, userId string, offset int, limit int) (*model.ChannelList, *model.AppError) GetPublicChannelsForTeam(teamId string, offset int, limit int) (*model.ChannelList, *model.AppError) GetPublicChannelsByIdsForTeam(teamId string, channelIds []string) (*model.ChannelList, *model.AppError) diff --git a/store/storetest/channel_store.go b/store/storetest/channel_store.go index f43c2af956..d2d1eb20ef 100644 --- a/store/storetest/channel_store.go +++ b/store/storetest/channel_store.go @@ -1160,6 +1160,9 @@ func testChannelStoreGetAllChannels(t *testing.T, ss store.Store, s SqlSupplier) assert.Equal(t, (*list)[1].Id, c3.Id) assert.Equal(t, (*list)[1].TeamDisplayName, "Name2") + count1, err := ss.Channel().GetAllChannelsCount(store.ChannelSearchOpts{}) + require.Nil(t, err) + list, err = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{IncludeDeleted: true}) require.Nil(t, err) assert.Len(t, *list, 3) @@ -1168,6 +1171,12 @@ func testChannelStoreGetAllChannels(t *testing.T, ss store.Store, s SqlSupplier) assert.Equal(t, (*list)[1].Id, c2.Id) assert.Equal(t, (*list)[2].Id, c3.Id) + count2, err := ss.Channel().GetAllChannelsCount(store.ChannelSearchOpts{IncludeDeleted: true}) + require.Nil(t, err) + require.True(t, func() bool { + return count2 > count1 + }()) + list, err = ss.Channel().GetAllChannels(0, 1, store.ChannelSearchOpts{IncludeDeleted: true}) require.Nil(t, err) assert.Len(t, *list, 1) diff --git a/store/storetest/mocks/ChannelStore.go b/store/storetest/mocks/ChannelStore.go index 8cf2a81b00..deb3ae23e7 100644 --- a/store/storetest/mocks/ChannelStore.go +++ b/store/storetest/mocks/ChannelStore.go @@ -278,6 +278,29 @@ func (_m *ChannelStore) GetAllChannels(page int, perPage int, opts store.Channel return r0, r1 } +// GetAllChannelsCount provides a mock function with given fields: opts +func (_m *ChannelStore) GetAllChannelsCount(opts store.ChannelSearchOpts) (int64, *model.AppError) { + ret := _m.Called(opts) + + var r0 int64 + if rf, ok := ret.Get(0).(func(store.ChannelSearchOpts) int64); ok { + r0 = rf(opts) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(store.ChannelSearchOpts) *model.AppError); ok { + r1 = rf(opts) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 +} + // GetAllChannelsForExportAfter provides a mock function with given fields: limit, afterId func (_m *ChannelStore) GetAllChannelsForExportAfter(limit int, afterId string) ([]*model.ChannelForExport, *model.AppError) { ret := _m.Called(limit, afterId)