SSO: add configurableProviders list to SSO service (#86622)

* add configurableProviders list to sso service

* address feedback
This commit is contained in:
Mihai Doarna 2024-04-23 10:00:43 +03:00 committed by GitHub
parent 51dcd1d9fd
commit 4d9e35ba57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 52 deletions

View File

@ -8,6 +8,8 @@ import (
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
@ -24,7 +26,6 @@ import (
"github.com/grafana/grafana/pkg/services/ssosettings/models"
"github.com/grafana/grafana/pkg/services/ssosettings/strategies"
"github.com/grafana/grafana/pkg/setting"
"github.com/prometheus/client_golang/prometheus"
)
var _ ssosettings.Service = (*Service)(nil)
@ -37,9 +38,10 @@ type Service struct {
secrets secrets.Service
metrics *metrics
fbStrategies []ssosettings.FallbackStrategy
providersList []string
reloadables map[string]ssosettings.Reloadable
fbStrategies []ssosettings.FallbackStrategy
providersList []string
configurableProviders map[string]bool
reloadables map[string]ssosettings.Reloadable
}
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
@ -50,27 +52,34 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
strategies.NewOAuthStrategy(cfg),
}
configurableProviders := make(map[string]bool)
for provider, enabled := range cfg.SSOSettingsConfigurableProviders {
configurableProviders[provider] = enabled
}
providersList := ssosettings.AllOAuthProviders
if licensing.FeatureEnabled(social.SAMLProviderName) {
fbStrategies = append(fbStrategies, strategies.NewSAMLStrategy(settingsProvider))
if cfg.SSOSettingsConfigurableProviders[social.SAMLProviderName] {
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsSAML) {
providersList = append(providersList, social.SAMLProviderName)
configurableProviders[social.SAMLProviderName] = true
}
}
store := database.ProvideStore(sqlStore)
svc := &Service{
logger: log.New("ssosettings.service"),
cfg: cfg,
store: store,
ac: ac,
fbStrategies: fbStrategies,
secrets: secrets,
metrics: newMetrics(registerer),
providersList: providersList,
reloadables: make(map[string]ssosettings.Reloadable),
logger: log.New("ssosettings.service"),
cfg: cfg,
store: store,
ac: ac,
fbStrategies: fbStrategies,
secrets: secrets,
metrics: newMetrics(registerer),
providersList: providersList,
configurableProviders: configurableProviders,
reloadables: make(map[string]ssosettings.Reloadable),
}
usageStats.RegisterMetricsFunc(svc.getUsageStats)
@ -160,7 +169,7 @@ func (s *Service) ListWithRedactedSecrets(ctx context.Context) ([]*models.SSOSet
return nil, err
}
configurableSettings := make([]*models.SSOSettings, 0, len(s.cfg.SSOSettingsConfigurableProviders))
configurableSettings := make([]*models.SSOSettings, 0, len(s.configurableProviders))
for _, provider := range storeSettings {
if s.isProviderConfigurable(provider.Provider) {
configurableSettings = append(configurableSettings, provider)
@ -431,8 +440,8 @@ func (s *Service) decryptSecrets(ctx context.Context, settings map[string]any) (
}
func (s *Service) isProviderConfigurable(provider string) bool {
_, ok := s.cfg.SSOSettingsConfigurableProviders[provider]
return ok
enabled, ok := s.configurableProviders[provider]
return ok && enabled
}
// removeSecrets removes all the secrets from the map and replaces them with a redacted password

View File

@ -241,7 +241,7 @@ func TestService_GetForProvider(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
if tc.setup != nil {
tc.setup(env)
}
@ -350,7 +350,7 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
if tc.setup != nil {
tc.setup(env)
}
@ -501,7 +501,7 @@ func TestService_List(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
if tc.setup != nil {
tc.setup(env)
}
@ -803,7 +803,7 @@ func TestService_ListWithRedactedSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
if tc.setup != nil {
tc.setup(env)
}
@ -827,7 +827,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("successfully upsert SSO settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := models.SSOSettings{
@ -890,7 +890,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if provider is not configurable", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.GrafanaComProviderName
settings := &models.SSOSettings{
@ -913,7 +913,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if provider was not found in reloadables", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := &models.SSOSettings{
@ -937,7 +937,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if validation fails", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := models.SSOSettings{
@ -961,7 +961,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
settings := &models.SSOSettings{
Provider: social.AzureADProviderName,
@ -982,7 +982,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.OktaProviderName
settings := models.SSOSettings{
@ -1007,7 +1007,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("should not update the current secret if the secret has not been updated", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := models.SSOSettings{
@ -1044,7 +1044,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := models.SSOSettings{
@ -1076,7 +1076,7 @@ func TestService_Upsert(t *testing.T) {
t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
settings := models.SSOSettings{
@ -1109,7 +1109,7 @@ func TestService_Delete(t *testing.T) {
t.Run("successfully delete SSO settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
var wg sync.WaitGroup
wg.Add(1)
@ -1147,7 +1147,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return error if SSO setting was not found for the specified provider", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t)
@ -1163,7 +1163,7 @@ func TestService_Delete(t *testing.T) {
t.Run("should not delete the SSO settings if the provider is not configurable", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
env.cfg.SSOSettingsConfigurableProviders = map[string]bool{social.AzureADProviderName: true}
provider := social.GrafanaComProviderName
@ -1176,7 +1176,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return error when store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
env.store.ExpectedError = errors.New("delete sso settings failed")
@ -1189,7 +1189,7 @@ func TestService_Delete(t *testing.T) {
t.Run("return successfully when the deletion was successful but reloading the settings fail", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t)
@ -1211,7 +1211,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("successfully reload settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
settingsList := []*models.SSOSettings{
{
@ -1251,7 +1251,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("successfully reload settings when some providers have empty settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
settingsList := []*models.SSOSettings{
{
@ -1281,7 +1281,7 @@ func TestService_DoReload(t *testing.T) {
t.Run("failed fetching the SSO settings", func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
provider := "github"
@ -1382,7 +1382,7 @@ func TestService_decryptSecrets(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, false, false, nil)
env := setupTestEnv(t, false, false, false)
if tc.setup != nil {
tc.setup(env)
@ -1407,12 +1407,12 @@ func Test_ProviderService(t *testing.T) {
tests := []struct {
name string
isLicenseEnabled bool
configurableProviders map[string]bool
samlEnabled bool
expectedProvidersList []string
strategiesLength int
}{
{
name: "should return all OAuth providers but not saml because the licensing feature is not enabled",
name: "should return all OAuth providers but not SAML because the licensing feature is not enabled",
isLicenseEnabled: false,
expectedProvidersList: []string{
"github",
@ -1426,7 +1426,7 @@ func Test_ProviderService(t *testing.T) {
strategiesLength: 1,
},
{
name: "should return all fallback strategies and it should return all OAuth providers but not saml because the licensing feature is enabled but the configurable provider is not setup",
name: "should return all fallback strategies and it should return all OAuth providers but not SAML because the licensing feature is enabled but the configurable provider is not setup",
isLicenseEnabled: true,
expectedProvidersList: []string{
"github",
@ -1440,9 +1440,9 @@ func Test_ProviderService(t *testing.T) {
strategiesLength: 2,
},
{
name: "should return all fallback strategies and it should return all OAuth providers and saml because the licensing feature is enabled and the provider is setup",
isLicenseEnabled: true,
configurableProviders: map[string]bool{"saml": true},
name: "should return all fallback strategies and it should return all OAuth providers and SAML because the licensing feature is enabled and the provider is setup",
isLicenseEnabled: true,
samlEnabled: true,
expectedProvidersList: []string{
"github",
"gitlab",
@ -1461,7 +1461,7 @@ func Test_ProviderService(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.configurableProviders)
env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.samlEnabled)
require.Equal(t, tc.expectedProvidersList, env.service.providersList)
require.Equal(t, tc.strategiesLength, len(env.service.fbStrategies))
@ -1469,7 +1469,7 @@ func Test_ProviderService(t *testing.T) {
}
}
func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool, extraConfigurableProviders map[string]bool) testEnv {
func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies, samlEnabled bool) testEnv {
t.Helper()
store := ssosettingstests.NewFakeStore()
@ -1491,10 +1491,6 @@ func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool
"gitlab": true,
}
for k, v := range extraConfigurableProviders {
configurableProviders[k] = v
}
cfg := &setting.Cfg{
SSOSettingsConfigurableProviders: configurableProviders,
Raw: iniFile,
@ -1503,12 +1499,17 @@ func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies bool
licensing := licensingtest.NewFakeLicensing()
licensing.On("FeatureEnabled", "saml").Return(isLicensingEnabled)
featureManager := featuremgmt.WithManager()
if samlEnabled {
featureManager = featuremgmt.WithManager(featuremgmt.FlagSsoSettingsSAML)
}
svc := ProvideService(
cfg,
&dbtest.FakeDB{},
accessControl,
routing.NewRouteRegister(),
featuremgmt.WithManager(nil),
featureManager,
secretsFakes.NewMockService(t),
&usagestats.UsageStatsMock{},
prometheus.NewRegistry(),