SSO: add configurableProviders list to SSO service (#86622)

* add configurableProviders list to sso service

* address feedback
This commit is contained in:
Mihai Doarna 2024-04-23 10:00:43 +03:00 committed by GitHub
parent 51dcd1d9fd
commit 4d9e35ba57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 52 deletions

View File

@ -8,6 +8,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/prometheus/client_golang/prometheus"
"github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
@ -24,7 +26,6 @@ import (
"github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/services/ssosettings/models"
"github.com/grafana/grafana/pkg/services/ssosettings/strategies" "github.com/grafana/grafana/pkg/services/ssosettings/strategies"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/prometheus/client_golang/prometheus"
) )
var _ ssosettings.Service = (*Service)(nil) var _ ssosettings.Service = (*Service)(nil)
@ -37,9 +38,10 @@ type Service struct {
secrets secrets.Service secrets secrets.Service
metrics *metrics metrics *metrics
fbStrategies []ssosettings.FallbackStrategy fbStrategies []ssosettings.FallbackStrategy
providersList []string providersList []string
reloadables map[string]ssosettings.Reloadable configurableProviders map[string]bool
reloadables map[string]ssosettings.Reloadable
} }
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl, func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
@ -50,27 +52,34 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
strategies.NewOAuthStrategy(cfg), strategies.NewOAuthStrategy(cfg),
} }
configurableProviders := make(map[string]bool)
for provider, enabled := range cfg.SSOSettingsConfigurableProviders {
configurableProviders[provider] = enabled
}
providersList := ssosettings.AllOAuthProviders providersList := ssosettings.AllOAuthProviders
if licensing.FeatureEnabled(social.SAMLProviderName) { if licensing.FeatureEnabled(social.SAMLProviderName) {
fbStrategies = append(fbStrategies, strategies.NewSAMLStrategy(settingsProvider)) fbStrategies = append(fbStrategies, strategies.NewSAMLStrategy(settingsProvider))
if cfg.SSOSettingsConfigurableProviders[social.SAMLProviderName] { if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsSAML) {
providersList = append(providersList, social.SAMLProviderName) providersList = append(providersList, social.SAMLProviderName)
configurableProviders[social.SAMLProviderName] = true
} }
} }
store := database.ProvideStore(sqlStore) store := database.ProvideStore(sqlStore)
svc := &Service{ svc := &Service{
logger: log.New("ssosettings.service"), logger: log.New("ssosettings.service"),
cfg: cfg, cfg: cfg,
store: store, store: store,
ac: ac, ac: ac,
fbStrategies: fbStrategies, fbStrategies: fbStrategies,
secrets: secrets, secrets: secrets,
metrics: newMetrics(registerer), metrics: newMetrics(registerer),
providersList: providersList, providersList: providersList,
reloadables: make(map[string]ssosettings.Reloadable), configurableProviders: configurableProviders,
reloadables: make(map[string]ssosettings.Reloadable),
} }
usageStats.RegisterMetricsFunc(svc.getUsageStats) usageStats.RegisterMetricsFunc(svc.getUsageStats)
@ -160,7 +169,7 @@ func (s *Service) ListWithRedactedSecrets(ctx context.Context) ([]*models.SSOSet
return nil, err return nil, err
} }
configurableSettings := make([]*models.SSOSettings, 0, len(s.cfg.SSOSettingsConfigurableProviders)) configurableSettings := make([]*models.SSOSettings, 0, len(s.configurableProviders))
for _, provider := range storeSettings { for _, provider := range storeSettings {
if s.isProviderConfigurable(provider.Provider) { if s.isProviderConfigurable(provider.Provider) {
configurableSettings = append(configurableSettings, provider) configurableSettings = append(configurableSettings, provider)
@ -431,8 +440,8 @@ func (s *Service) decryptSecrets(ctx context.Context, settings map[string]any) (
} }
func (s *Service) isProviderConfigurable(provider string) bool { func (s *Service) isProviderConfigurable(provider string) bool {
_, ok := s.cfg.SSOSettingsConfigurableProviders[provider] enabled, ok := s.configurableProviders[provider]
return ok return ok && enabled
} }
// removeSecrets removes all the secrets from the map and replaces them with a redacted password // removeSecrets removes all the secrets from the map and replaces them with a redacted password

View File

@ -241,7 +241,7 @@ func TestService_GetForProvider(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
if tc.setup != nil { if tc.setup != nil {
tc.setup(env) tc.setup(env)
} }
@ -350,7 +350,7 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
if tc.setup != nil { if tc.setup != nil {
tc.setup(env) tc.setup(env)
} }
@ -501,7 +501,7 @@ func TestService_List(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
if tc.setup != nil { if tc.setup != nil {
tc.setup(env) tc.setup(env)
} }
@ -803,7 +803,7 @@ func TestService_ListWithRedactedSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
if tc.setup != nil { if tc.setup != nil {
tc.setup(env) tc.setup(env)
} }
@ -827,7 +827,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("successfully upsert SSO settings", func(t *testing.T) { t.Run("successfully upsert SSO settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -890,7 +890,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if provider is not configurable", func(t *testing.T) { t.Run("returns error if provider is not configurable", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.GrafanaComProviderName provider := social.GrafanaComProviderName
settings := &models.SSOSettings{ settings := &models.SSOSettings{
@ -913,7 +913,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if provider was not found in reloadables", func(t *testing.T) { t.Run("returns error if provider was not found in reloadables", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := &models.SSOSettings{ settings := &models.SSOSettings{
@ -937,7 +937,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if validation fails", func(t *testing.T) { t.Run("returns error if validation fails", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -961,7 +961,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) { t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
settings := &models.SSOSettings{ settings := &models.SSOSettings{
Provider: social.AzureADProviderName, Provider: social.AzureADProviderName,
@ -982,7 +982,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if secrets encryption failed", func(t *testing.T) { t.Run("returns error if secrets encryption failed", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.OktaProviderName provider := social.OktaProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -1007,7 +1007,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("should not update the current secret if the secret has not been updated", func(t *testing.T) { t.Run("should not update the current secret if the secret has not been updated", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -1044,7 +1044,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if store failed to upsert settings", func(t *testing.T) { t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -1076,7 +1076,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) { t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := models.SSOSettings{
@ -1109,7 +1109,7 @@ func TestService_Delete(t *testing.T) {
t.Run("successfully delete SSO settings", func(t *testing.T) { t.Run("successfully delete SSO settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -1147,7 +1147,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return error if SSO setting was not found for the specified provider", func(t *testing.T) { t.Run("return error if SSO setting was not found for the specified provider", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t) reloadable := ssosettingstests.NewMockReloadable(t)
@ -1163,7 +1163,7 @@ func TestService_Delete(t *testing.T) {
t.Run("should not delete the SSO settings if the provider is not configurable", func(t *testing.T) { t.Run("should not delete the SSO settings if the provider is not configurable", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
env.cfg.SSOSettingsConfigurableProviders = map[string]bool{social.AzureADProviderName: true} env.cfg.SSOSettingsConfigurableProviders = map[string]bool{social.AzureADProviderName: true}
provider := social.GrafanaComProviderName provider := social.GrafanaComProviderName
@ -1176,7 +1176,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return error when store fails to delete the SSO settings for the specified provider", func(t *testing.T) { t.Run("return error when store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
env.store.ExpectedError = errors.New("delete sso settings failed") env.store.ExpectedError = errors.New("delete sso settings failed")
@ -1189,7 +1189,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return successfully when the deletion was successful but reloading the settings fail", func(t *testing.T) { t.Run("return successfully when the deletion was successful but reloading the settings fail", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t) reloadable := ssosettingstests.NewMockReloadable(t)
@ -1211,7 +1211,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("successfully reload settings", func(t *testing.T) { t.Run("successfully reload settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
settingsList := []*models.SSOSettings{ settingsList := []*models.SSOSettings{
{ {
@ -1251,7 +1251,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("successfully reload settings when some providers have empty settings", func(t *testing.T) { t.Run("successfully reload settings when some providers have empty settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
settingsList := []*models.SSOSettings{ settingsList := []*models.SSOSettings{
{ {
@ -1281,7 +1281,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("failed fetching the SSO settings", func(t *testing.T) { t.Run("failed fetching the SSO settings", func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
provider := "github" provider := "github"
@ -1382,7 +1382,7 @@ func TestService_decryptSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, false, false, nil) env := setupTestEnv(t, false, false, false)
if tc.setup != nil { if tc.setup != nil {
tc.setup(env) tc.setup(env)
@ -1407,12 +1407,12 @@ func Test_ProviderService(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
isLicenseEnabled bool isLicenseEnabled bool
configurableProviders map[string]bool samlEnabled bool
expectedProvidersList []string expectedProvidersList []string
strategiesLength int strategiesLength int
}{ }{
{ {
name: "should return all OAuth providers but not saml because the licensing feature is not enabled", name: "should return all OAuth providers but not SAML because the licensing feature is not enabled",
isLicenseEnabled: false, isLicenseEnabled: false,
expectedProvidersList: []string{ expectedProvidersList: []string{
"github", "github",
@ -1426,7 +1426,7 @@ func Test_ProviderService(t *testing.T) {
strategiesLength: 1, strategiesLength: 1,
}, },
{ {
name: "should return all fallback strategies and it should return all OAuth providers but not saml because the licensing feature is enabled but the configurable provider is not setup", name: "should return all fallback strategies and it should return all OAuth providers but not SAML because the licensing feature is enabled but the configurable provider is not setup",
isLicenseEnabled: true, isLicenseEnabled: true,
expectedProvidersList: []string{ expectedProvidersList: []string{
"github", "github",
@ -1440,9 +1440,9 @@ func Test_ProviderService(t *testing.T) {
strategiesLength: 2, strategiesLength: 2,
}, },
{ {
name: "should return all fallback strategies and it should return all OAuth providers and saml because the licensing feature is enabled and the provider is setup", name: "should return all fallback strategies and it should return all OAuth providers and SAML because the licensing feature is enabled and the provider is setup",
isLicenseEnabled: true, isLicenseEnabled: true,
configurableProviders: map[string]bool{"saml": true}, samlEnabled: true,
expectedProvidersList: []string{ expectedProvidersList: []string{
"github", "github",
"gitlab", "gitlab",
@ -1461,7 +1461,7 @@ func Test_ProviderService(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.configurableProviders) env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.samlEnabled)
require.Equal(t, tc.expectedProvidersList, env.service.providersList) require.Equal(t, tc.expectedProvidersList, env.service.providersList)
require.Equal(t, tc.strategiesLength, len(env.service.fbStrategies)) require.Equal(t, tc.strategiesLength, len(env.service.fbStrategies))
@ -1469,7 +1469,7 @@ func Test_ProviderService(t *testing.T) {
} }
} }
func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool, extraConfigurableProviders map[string]bool) testEnv { func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies, samlEnabled bool) testEnv {
t.Helper() t.Helper()
store := ssosettingstests.NewFakeStore() store := ssosettingstests.NewFakeStore()
@ -1491,10 +1491,6 @@ func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool
"gitlab": true, "gitlab": true,
} }
for k, v := range extraConfigurableProviders {
configurableProviders[k] = v
}
cfg := &setting.Cfg{ cfg := &setting.Cfg{
SSOSettingsConfigurableProviders: configurableProviders, SSOSettingsConfigurableProviders: configurableProviders,
Raw: iniFile, Raw: iniFile,
@ -1503,12 +1499,17 @@ func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool
licensing := licensingtest.NewFakeLicensing() licensing := licensingtest.NewFakeLicensing()
licensing.On("FeatureEnabled", "saml").Return(isLicensingEnabled) licensing.On("FeatureEnabled", "saml").Return(isLicensingEnabled)
featureManager := featuremgmt.WithManager()
if samlEnabled {
featureManager = featuremgmt.WithManager(featuremgmt.FlagSsoSettingsSAML)
}
svc := ProvideService( svc := ProvideService(
cfg, cfg,
&dbtest.FakeDB{}, &dbtest.FakeDB{},
accessControl, accessControl,
routing.NewRouteRegister(), routing.NewRouteRegister(),
featuremgmt.WithManager(nil), featureManager,
secretsFakes.NewMockService(t), secretsFakes.NewMockService(t),
&usagestats.UsageStatsMock{}, &usagestats.UsageStatsMock{},
prometheus.NewRegistry(), prometheus.NewRegistry(),