mirror of
https://github.com/grafana/grafana.git
synced 2025-01-09 23:53:25 -06:00
Auth: Remove oAuthProviders from Social service (#78732)
* Remove oauthProviders from social svc * Add EnabledFn to supportbundles.Collector
This commit is contained in:
parent
86311e3a33
commit
79577e4929
@ -38,9 +38,8 @@ const (
|
||||
type SocialService struct {
|
||||
cfg *setting.Cfg
|
||||
|
||||
socialMap map[string]SocialConnector
|
||||
oAuthProvider map[string]*OAuthInfo
|
||||
log log.Logger
|
||||
socialMap map[string]SocialConnector
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
type OAuthInfo struct {
|
||||
@ -85,10 +84,9 @@ func ProvideService(cfg *setting.Cfg,
|
||||
cache remotecache.CacheStorage,
|
||||
) *SocialService {
|
||||
ss := &SocialService{
|
||||
cfg: cfg,
|
||||
oAuthProvider: make(map[string]*OAuthInfo),
|
||||
socialMap: make(map[string]SocialConnector),
|
||||
log: log.New("login.social"),
|
||||
cfg: cfg,
|
||||
socialMap: make(map[string]SocialConnector),
|
||||
log: log.New("login.social"),
|
||||
}
|
||||
|
||||
usageStats.RegisterMetricsFunc(ss.getUsageStats)
|
||||
@ -117,7 +115,6 @@ func ProvideService(cfg *setting.Cfg,
|
||||
}
|
||||
|
||||
ss.socialMap[name] = conn
|
||||
ss.oAuthProvider[name] = ss.socialMap[name].GetOAuthInfo()
|
||||
}
|
||||
|
||||
ss.registerSupportBundleCollectors(bundleRegistry)
|
||||
@ -322,20 +319,8 @@ func getRoleFromSearch(role string) (org.RoleType, bool) {
|
||||
func (ss *SocialService) GetOAuthProviders() map[string]bool {
|
||||
result := map[string]bool{}
|
||||
|
||||
if ss.cfg == nil || ss.cfg.Raw == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, name := range allOauthes {
|
||||
if name == "grafananet" {
|
||||
name = grafanaCom
|
||||
}
|
||||
|
||||
sec := ss.cfg.Raw.Section("auth." + name)
|
||||
if sec == nil {
|
||||
continue
|
||||
}
|
||||
result[name] = sec.Key("enabled").MustBool()
|
||||
for name, conn := range ss.socialMap {
|
||||
result[name] = conn.GetOAuthInfo().Enabled
|
||||
}
|
||||
|
||||
return result
|
||||
@ -344,11 +329,16 @@ func (ss *SocialService) GetOAuthProviders() map[string]bool {
|
||||
func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
|
||||
name = strings.TrimPrefix(name, "oauth_")
|
||||
info, ok := ss.oAuthProvider[name]
|
||||
provider, ok := ss.socialMap[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not find %q in OAuth Settings", name)
|
||||
}
|
||||
|
||||
info := provider.GetOAuthInfo()
|
||||
if !info.Enabled {
|
||||
return nil, fmt.Errorf("oauth provider %q is not enabled", name)
|
||||
}
|
||||
|
||||
// handle call back
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@ -404,11 +394,23 @@ func (ss *SocialService) GetConnector(name string) (SocialConnector, error) {
|
||||
}
|
||||
|
||||
func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo {
|
||||
return ss.oAuthProvider[name]
|
||||
connector, ok := ss.socialMap[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return connector.GetOAuthInfo()
|
||||
}
|
||||
|
||||
// GetOAuthInfoProviders returns enabled OAuth providers
|
||||
func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo {
|
||||
return ss.oAuthProvider
|
||||
result := map[string]*OAuthInfo{}
|
||||
for name, connector := range ss.socialMap {
|
||||
info := connector.GetOAuthInfo()
|
||||
if info.Enabled {
|
||||
result[name] = info
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (ss *SocialService) getUsageStats(ctx context.Context) (map[string]any, error) {
|
||||
|
@ -13,19 +13,20 @@ import (
|
||||
)
|
||||
|
||||
func (ss *SocialService) registerSupportBundleCollectors(bundleRegistry supportbundles.Service) {
|
||||
for name := range ss.oAuthProvider {
|
||||
for name, connector := range ss.socialMap {
|
||||
bundleRegistry.RegisterSupportItemCollector(supportbundles.Collector{
|
||||
UID: "oauth-" + name,
|
||||
DisplayName: "OAuth " + strings.Title(strings.ReplaceAll(name, "_", " ")),
|
||||
Description: "OAuth configuration and healthchecks for " + name,
|
||||
IncludedByDefault: false,
|
||||
Default: false,
|
||||
Fn: ss.supportBundleCollectorFn(name, ss.socialMap[name], ss.oAuthProvider[name]),
|
||||
EnabledFn: func() bool { return connector.GetOAuthInfo().Enabled },
|
||||
Fn: ss.supportBundleCollectorFn(name, connector),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector, oinfo *OAuthInfo) func(context.Context) (*supportbundles.SupportItem, error) {
|
||||
func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector) func(context.Context) (*supportbundles.SupportItem, error) {
|
||||
return func(ctx context.Context) (*supportbundles.SupportItem, error) {
|
||||
bWriter := bytes.NewBuffer(nil)
|
||||
|
||||
@ -37,6 +38,8 @@ func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnecto
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oinfo := sc.GetOAuthInfo()
|
||||
|
||||
bWriter.WriteString("```toml\n")
|
||||
errM := toml.NewEncoder(bWriter).Encode(oinfo)
|
||||
if errM != nil {
|
||||
|
@ -46,6 +46,8 @@ type Collector struct {
|
||||
Default bool `json:"default"`
|
||||
// Fn is the function that collects the support item.
|
||||
Fn CollectorFunc `json:"-"`
|
||||
// EnabledFn is a function that determines if the collector is enabled. If nil, the collector is always enabled.
|
||||
EnabledFn func() bool `json:"-"`
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
|
@ -74,9 +74,15 @@ func (s *Service) bundle(ctx context.Context, collectors []string, uid string) (
|
||||
files := map[string][]byte{}
|
||||
|
||||
for _, collector := range s.bundleRegistry.Collectors() {
|
||||
if !lookup[collector.UID] && !collector.IncludedByDefault {
|
||||
collectorEnabled := true
|
||||
if collector.EnabledFn != nil {
|
||||
collectorEnabled = collector.EnabledFn()
|
||||
}
|
||||
|
||||
if !(lookup[collector.UID] || collector.IncludedByDefault) || !collectorEnabled {
|
||||
continue
|
||||
}
|
||||
|
||||
item, err := collector.Fn(ctx)
|
||||
if err != nil {
|
||||
s.log.Warn("Failed to collect support bundle item", "error", err, "collector", collector.UID)
|
||||
|
@ -38,12 +38,16 @@ func TestService_bundleCreate(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
|
||||
collector := basicCollector(cfg)
|
||||
disabledCollector := settingsCollector(setting.ProvideProvider(cfg))
|
||||
disabledCollector.EnabledFn = func() bool { return false }
|
||||
|
||||
s.bundleRegistry.RegisterSupportItemCollector(collector)
|
||||
s.bundleRegistry.RegisterSupportItemCollector(disabledCollector)
|
||||
|
||||
createdBundle, err := s.store.Create(context.Background(), &user.SignedInUser{UserID: 1, Login: "bob"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.startBundleWork(context.Background(), []string{collector.UID}, createdBundle.UID)
|
||||
s.startBundleWork(context.Background(), []string{collector.UID, disabledCollector.UID}, createdBundle.UID)
|
||||
|
||||
bundle, err := s.get(context.Background(), createdBundle.UID)
|
||||
require.NoError(t, err)
|
||||
|
Loading…
Reference in New Issue
Block a user