diff --git a/app/user.go b/app/user.go index b14e8d86ad..1c2ccb3bfb 100644 --- a/app/user.go +++ b/app/user.go @@ -2134,11 +2134,7 @@ func (a *App) userBelongsToTeams(userId string, teamIds []string) (bool, *model. } func (a *App) userBelongsToChannels(userId string, channelIds []string) (bool, *model.AppError) { - result := <-a.Srv.Store.Channel().UserBelongsToChannels(userId, channelIds) - if result.Err != nil { - return false, result.Err - } - return result.Data.(bool), nil + return a.Srv.Store.Channel().UserBelongsToChannels(userId, channelIds) } func (a *App) GetViewUsersRestrictions(userId string) (*model.ViewUsersRestrictions, *model.AppError) { diff --git a/store/sqlstore/channel_store.go b/store/sqlstore/channel_store.go index 2a9ae4eb38..c0ecfa7d27 100644 --- a/store/sqlstore/channel_store.go +++ b/store/sqlstore/channel_store.go @@ -2634,26 +2634,22 @@ func (s SqlChannelStore) GetChannelsBatchForIndexing(startTime, endTime int64, l return channels, nil } -func (s SqlChannelStore) UserBelongsToChannels(userId string, channelIds []string) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - query := s.getQueryBuilder(). - Select("Count(*)"). - From("ChannelMembers"). - Where(sq.And{ - sq.Eq{"UserId": userId}, - sq.Eq{"ChannelId": channelIds}, - }) +func (s SqlChannelStore) UserBelongsToChannels(userId string, channelIds []string) (bool, *model.AppError) { + query := s.getQueryBuilder(). + Select("Count(*)"). + From("ChannelMembers"). + Where(sq.And{ + sq.Eq{"UserId": userId}, + sq.Eq{"ChannelId": channelIds}, + }) - queryString, args, err := query.ToSql() - if err != nil { - result.Err = model.NewAppError("SqlChannelStore.UserBelongsToChannels", "store.sql_channel.user_belongs_to_channels.app_error", nil, err.Error(), http.StatusInternalServerError) - return - } - c, err := s.GetReplica().SelectInt(queryString, args...) - if err != nil { - result.Err = model.NewAppError("SqlChannelStore.UserBelongsToChannels", "store.sql_channel.user_belongs_to_channels.app_error", nil, err.Error(), http.StatusInternalServerError) - return - } - result.Data = c > 0 - }) + queryString, args, err := query.ToSql() + if err != nil { + return false, model.NewAppError("SqlChannelStore.UserBelongsToChannels", "store.sql_channel.user_belongs_to_channels.app_error", nil, err.Error(), http.StatusInternalServerError) + } + c, err := s.GetReplica().SelectInt(queryString, args...) + if err != nil { + return false, model.NewAppError("SqlChannelStore.UserBelongsToChannels", "store.sql_channel.user_belongs_to_channels.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return c > 0, nil } diff --git a/store/store.go b/store/store.go index 38f81b96c8..cf199d88c0 100644 --- a/store/store.go +++ b/store/store.go @@ -199,7 +199,7 @@ type ChannelStore interface { GetChannelMembersForExport(userId string, teamId string) StoreChannel RemoveAllDeactivatedMembers(channelId string) StoreChannel GetChannelsBatchForIndexing(startTime, endTime int64, limit int) ([]*model.Channel, *model.AppError) - UserBelongsToChannels(userId string, channelIds []string) StoreChannel + UserBelongsToChannels(userId string, channelIds []string) (bool, *model.AppError) } type ChannelMemberHistoryStore interface { diff --git a/store/storetest/mocks/ChannelStore.go b/store/storetest/mocks/ChannelStore.go index 0b2009ad6d..b11fed3d75 100644 --- a/store/storetest/mocks/ChannelStore.go +++ b/store/storetest/mocks/ChannelStore.go @@ -1240,17 +1240,24 @@ func (_m *ChannelStore) UpdateMember(member *model.ChannelMember) store.StoreCha } // UserBelongsToChannels provides a mock function with given fields: userId, channelIds -func (_m *ChannelStore) UserBelongsToChannels(userId string, channelIds []string) store.StoreChannel { +func (_m *ChannelStore) UserBelongsToChannels(userId string, channelIds []string) (bool, *model.AppError) { ret := _m.Called(userId, channelIds) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(string, []string) store.StoreChannel); ok { + var r0 bool + if rf, ok := ret.Get(0).(func(string, []string) bool); ok { r0 = rf(userId, channelIds) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(bool) + } + + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string, []string) *model.AppError); ok { + r1 = rf(userId, channelIds) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) } } - return r0 + return r0, r1 }