Auth: Remove oAuthProviders from Social service (#78732)

* Remove oauthProviders from social svc

* Add EnabledFn to supportbundles.Collector
This commit is contained in:
Misi 2023-11-30 09:30:35 +01:00 committed by GitHub
parent 86311e3a33
commit 79577e4929
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 30 deletions

View File

@ -38,9 +38,8 @@ const (
type SocialService struct {
cfg *setting.Cfg
socialMap map[string]SocialConnector
oAuthProvider map[string]*OAuthInfo
log log.Logger
socialMap map[string]SocialConnector
log log.Logger
}
type OAuthInfo struct {
@ -85,10 +84,9 @@ func ProvideService(cfg *setting.Cfg,
cache remotecache.CacheStorage,
) *SocialService {
ss := &SocialService{
cfg: cfg,
oAuthProvider: make(map[string]*OAuthInfo),
socialMap: make(map[string]SocialConnector),
log: log.New("login.social"),
cfg: cfg,
socialMap: make(map[string]SocialConnector),
log: log.New("login.social"),
}
usageStats.RegisterMetricsFunc(ss.getUsageStats)
@ -117,7 +115,6 @@ func ProvideService(cfg *setting.Cfg,
}
ss.socialMap[name] = conn
ss.oAuthProvider[name] = ss.socialMap[name].GetOAuthInfo()
}
ss.registerSupportBundleCollectors(bundleRegistry)
@ -322,20 +319,8 @@ func getRoleFromSearch(role string) (org.RoleType, bool) {
func (ss *SocialService) GetOAuthProviders() map[string]bool {
result := map[string]bool{}
if ss.cfg == nil || ss.cfg.Raw == nil {
return result
}
for _, name := range allOauthes {
if name == "grafananet" {
name = grafanaCom
}
sec := ss.cfg.Raw.Section("auth." + name)
if sec == nil {
continue
}
result[name] = sec.Key("enabled").MustBool()
for name, conn := range ss.socialMap {
result[name] = conn.GetOAuthInfo().Enabled
}
return result
@ -344,11 +329,16 @@ func (ss *SocialService) GetOAuthProviders() map[string]bool {
func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
name = strings.TrimPrefix(name, "oauth_")
info, ok := ss.oAuthProvider[name]
provider, ok := ss.socialMap[name]
if !ok {
return nil, fmt.Errorf("could not find %q in OAuth Settings", name)
}
info := provider.GetOAuthInfo()
if !info.Enabled {
return nil, fmt.Errorf("oauth provider %q is not enabled", name)
}
// handle call back
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -404,11 +394,23 @@ func (ss *SocialService) GetConnector(name string) (SocialConnector, error) {
}
func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo {
return ss.oAuthProvider[name]
connector, ok := ss.socialMap[name]
if !ok {
return nil
}
return connector.GetOAuthInfo()
}
// GetOAuthInfoProviders returns enabled OAuth providers
func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo {
return ss.oAuthProvider
result := map[string]*OAuthInfo{}
for name, connector := range ss.socialMap {
info := connector.GetOAuthInfo()
if info.Enabled {
result[name] = info
}
}
return result
}
func (ss *SocialService) getUsageStats(ctx context.Context) (map[string]any, error) {

View File

@ -13,19 +13,20 @@ import (
)
func (ss *SocialService) registerSupportBundleCollectors(bundleRegistry supportbundles.Service) {
for name := range ss.oAuthProvider {
for name, connector := range ss.socialMap {
bundleRegistry.RegisterSupportItemCollector(supportbundles.Collector{
UID: "oauth-" + name,
DisplayName: "OAuth " + strings.Title(strings.ReplaceAll(name, "_", " ")),
Description: "OAuth configuration and healthchecks for " + name,
IncludedByDefault: false,
Default: false,
Fn: ss.supportBundleCollectorFn(name, ss.socialMap[name], ss.oAuthProvider[name]),
EnabledFn: func() bool { return connector.GetOAuthInfo().Enabled },
Fn: ss.supportBundleCollectorFn(name, connector),
})
}
}
func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector, oinfo *OAuthInfo) func(context.Context) (*supportbundles.SupportItem, error) {
func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector) func(context.Context) (*supportbundles.SupportItem, error) {
return func(ctx context.Context) (*supportbundles.SupportItem, error) {
bWriter := bytes.NewBuffer(nil)
@ -37,6 +38,8 @@ func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnecto
return nil, err
}
oinfo := sc.GetOAuthInfo()
bWriter.WriteString("```toml\n")
errM := toml.NewEncoder(bWriter).Encode(oinfo)
if errM != nil {

View File

@ -46,6 +46,8 @@ type Collector struct {
Default bool `json:"default"`
// Fn is the function that collects the support item.
Fn CollectorFunc `json:"-"`
// EnabledFn is a function that determines if the collector is enabled. If nil, the collector is always enabled.
EnabledFn func() bool `json:"-"`
}
type Service interface {

View File

@ -74,9 +74,15 @@ func (s *Service) bundle(ctx context.Context, collectors []string, uid string) (
files := map[string][]byte{}
for _, collector := range s.bundleRegistry.Collectors() {
if !lookup[collector.UID] && !collector.IncludedByDefault {
collectorEnabled := true
if collector.EnabledFn != nil {
collectorEnabled = collector.EnabledFn()
}
if !(lookup[collector.UID] || collector.IncludedByDefault) || !collectorEnabled {
continue
}
item, err := collector.Fn(ctx)
if err != nil {
s.log.Warn("Failed to collect support bundle item", "error", err, "collector", collector.UID)

View File

@ -38,12 +38,16 @@ func TestService_bundleCreate(t *testing.T) {
cfg := setting.NewCfg()
collector := basicCollector(cfg)
disabledCollector := settingsCollector(setting.ProvideProvider(cfg))
disabledCollector.EnabledFn = func() bool { return false }
s.bundleRegistry.RegisterSupportItemCollector(collector)
s.bundleRegistry.RegisterSupportItemCollector(disabledCollector)
createdBundle, err := s.store.Create(context.Background(), &user.SignedInUser{UserID: 1, Login: "bob"})
require.NoError(t, err)
s.startBundleWork(context.Background(), []string{collector.UID}, createdBundle.UID)
s.startBundleWork(context.Background(), []string{collector.UID, disabledCollector.UID}, createdBundle.UID)
bundle, err := s.get(context.Background(), createdBundle.UID)
require.NoError(t, err)