diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index f189f8e5365..31f3e7827eb 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -38,9 +38,8 @@ const ( type SocialService struct { cfg *setting.Cfg - socialMap map[string]SocialConnector - oAuthProvider map[string]*OAuthInfo - log log.Logger + socialMap map[string]SocialConnector + log log.Logger } type OAuthInfo struct { @@ -85,10 +84,9 @@ func ProvideService(cfg *setting.Cfg, cache remotecache.CacheStorage, ) *SocialService { ss := &SocialService{ - cfg: cfg, - oAuthProvider: make(map[string]*OAuthInfo), - socialMap: make(map[string]SocialConnector), - log: log.New("login.social"), + cfg: cfg, + socialMap: make(map[string]SocialConnector), + log: log.New("login.social"), } usageStats.RegisterMetricsFunc(ss.getUsageStats) @@ -117,7 +115,6 @@ func ProvideService(cfg *setting.Cfg, } ss.socialMap[name] = conn - ss.oAuthProvider[name] = ss.socialMap[name].GetOAuthInfo() } ss.registerSupportBundleCollectors(bundleRegistry) @@ -322,20 +319,8 @@ func getRoleFromSearch(role string) (org.RoleType, bool) { func (ss *SocialService) GetOAuthProviders() map[string]bool { result := map[string]bool{} - if ss.cfg == nil || ss.cfg.Raw == nil { - return result - } - - for _, name := range allOauthes { - if name == "grafananet" { - name = grafanaCom - } - - sec := ss.cfg.Raw.Section("auth." + name) - if sec == nil { - continue - } - result[name] = sec.Key("enabled").MustBool() + for name, conn := range ss.socialMap { + result[name] = conn.GetOAuthInfo().Enabled } return result @@ -344,11 +329,16 @@ func (ss *SocialService) GetOAuthProviders() map[string]bool { func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) { // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does name = strings.TrimPrefix(name, "oauth_") - info, ok := ss.oAuthProvider[name] + provider, ok := ss.socialMap[name] if !ok { return nil, fmt.Errorf("could not find %q in OAuth Settings", name) } + info := provider.GetOAuthInfo() + if !info.Enabled { + return nil, fmt.Errorf("oauth provider %q is not enabled", name) + } + // handle call back tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -404,11 +394,23 @@ func (ss *SocialService) GetConnector(name string) (SocialConnector, error) { } func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo { - return ss.oAuthProvider[name] + connector, ok := ss.socialMap[name] + if !ok { + return nil + } + return connector.GetOAuthInfo() } +// GetOAuthInfoProviders returns enabled OAuth providers func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo { - return ss.oAuthProvider + result := map[string]*OAuthInfo{} + for name, connector := range ss.socialMap { + info := connector.GetOAuthInfo() + if info.Enabled { + result[name] = info + } + } + return result } func (ss *SocialService) getUsageStats(ctx context.Context) (map[string]any, error) { diff --git a/pkg/login/social/support_bundle.go b/pkg/login/social/support_bundle.go index 87ee04dd5f4..54dafcb8b26 100644 --- a/pkg/login/social/support_bundle.go +++ b/pkg/login/social/support_bundle.go @@ -13,19 +13,20 @@ import ( ) func (ss *SocialService) registerSupportBundleCollectors(bundleRegistry supportbundles.Service) { - for name := range ss.oAuthProvider { + for name, connector := range ss.socialMap { bundleRegistry.RegisterSupportItemCollector(supportbundles.Collector{ UID: "oauth-" + name, DisplayName: "OAuth " + strings.Title(strings.ReplaceAll(name, "_", " ")), Description: "OAuth configuration and healthchecks for " + name, IncludedByDefault: false, Default: false, - Fn: ss.supportBundleCollectorFn(name, ss.socialMap[name], ss.oAuthProvider[name]), + EnabledFn: func() bool { return connector.GetOAuthInfo().Enabled }, + Fn: ss.supportBundleCollectorFn(name, connector), }) } } -func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector, oinfo *OAuthInfo) func(context.Context) (*supportbundles.SupportItem, error) { +func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector) func(context.Context) (*supportbundles.SupportItem, error) { return func(ctx context.Context) (*supportbundles.SupportItem, error) { bWriter := bytes.NewBuffer(nil) @@ -37,6 +38,8 @@ func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnecto return nil, err } + oinfo := sc.GetOAuthInfo() + bWriter.WriteString("```toml\n") errM := toml.NewEncoder(bWriter).Encode(oinfo) if errM != nil { diff --git a/pkg/services/supportbundles/interface.go b/pkg/services/supportbundles/interface.go index e876412087b..89cff688525 100644 --- a/pkg/services/supportbundles/interface.go +++ b/pkg/services/supportbundles/interface.go @@ -46,6 +46,8 @@ type Collector struct { Default bool `json:"default"` // Fn is the function that collects the support item. Fn CollectorFunc `json:"-"` + // EnabledFn is a function that determines if the collector is enabled. If nil, the collector is always enabled. + EnabledFn func() bool `json:"-"` } type Service interface { diff --git a/pkg/services/supportbundles/supportbundlesimpl/service_bundle.go b/pkg/services/supportbundles/supportbundlesimpl/service_bundle.go index e978c58472c..fddd6488812 100644 --- a/pkg/services/supportbundles/supportbundlesimpl/service_bundle.go +++ b/pkg/services/supportbundles/supportbundlesimpl/service_bundle.go @@ -74,9 +74,15 @@ func (s *Service) bundle(ctx context.Context, collectors []string, uid string) ( files := map[string][]byte{} for _, collector := range s.bundleRegistry.Collectors() { - if !lookup[collector.UID] && !collector.IncludedByDefault { + collectorEnabled := true + if collector.EnabledFn != nil { + collectorEnabled = collector.EnabledFn() + } + + if !(lookup[collector.UID] || collector.IncludedByDefault) || !collectorEnabled { continue } + item, err := collector.Fn(ctx) if err != nil { s.log.Warn("Failed to collect support bundle item", "error", err, "collector", collector.UID) diff --git a/pkg/services/supportbundles/supportbundlesimpl/service_bundle_test.go b/pkg/services/supportbundles/supportbundlesimpl/service_bundle_test.go index 0a840addcae..016e72d5dbb 100644 --- a/pkg/services/supportbundles/supportbundlesimpl/service_bundle_test.go +++ b/pkg/services/supportbundles/supportbundlesimpl/service_bundle_test.go @@ -38,12 +38,16 @@ func TestService_bundleCreate(t *testing.T) { cfg := setting.NewCfg() collector := basicCollector(cfg) + disabledCollector := settingsCollector(setting.ProvideProvider(cfg)) + disabledCollector.EnabledFn = func() bool { return false } + s.bundleRegistry.RegisterSupportItemCollector(collector) + s.bundleRegistry.RegisterSupportItemCollector(disabledCollector) createdBundle, err := s.store.Create(context.Background(), &user.SignedInUser{UserID: 1, Login: "bob"}) require.NoError(t, err) - s.startBundleWork(context.Background(), []string{collector.UID}, createdBundle.UID) + s.startBundleWork(context.Background(), []string{collector.UID, disabledCollector.UID}, createdBundle.UID) bundle, err := s.get(context.Background(), createdBundle.UID) require.NoError(t, err)