Auth: Reload OAuth provider after deletion of the current settings (#81374)

* Reload after deletion of the current settings

* Add grafana_ssosettings_setting_reload_failure_total counter

* Returns successfully if data reload failed
This commit is contained in:
Misi 2024-01-29 12:02:04 +01:00 committed by GitHub
parent a3fda08d4e
commit 7e96a2be56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 117 additions and 13 deletions

View File

@ -73,7 +73,7 @@ func TestSocialService_ProvideService(t *testing.T) {
accessControl := acimpl.ProvideAccessControl(cfg)
sqlStore := db.InitTestDB(t)
ssoSettingsSvc := ssosettingsimpl.ProvideService(cfg, sqlStore, accessControl, routing.NewRouteRegister(), featuremgmt.WithFeatures(), secrets, &usagestats.UsageStatsMock{})
ssoSettingsSvc := ssosettingsimpl.ProvideService(cfg, sqlStore, accessControl, routing.NewRouteRegister(), featuremgmt.WithFeatures(), secrets, &usagestats.UsageStatsMock{}, nil)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {

View File

@ -0,0 +1,31 @@
package ssosettingsimpl
import "github.com/prometheus/client_golang/prometheus"
const (
metricsNamespace = "grafana"
metricsSubSystem = "ssosettings"
)
type metrics struct {
reloadFailures *prometheus.CounterVec
}
func newMetrics(reg prometheus.Registerer) *metrics {
m := &metrics{
reloadFailures: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: metricsSubSystem,
Name: "setting_reload_failures_total",
Help: "Number of SSO Setting reload failures.",
}, []string{"provider"}),
}
if reg != nil {
reg.MustRegister(
m.reloadFailures,
)
}
return m
}

View File

@ -21,6 +21,7 @@ import (
"github.com/grafana/grafana/pkg/services/ssosettings/models"
"github.com/grafana/grafana/pkg/services/ssosettings/strategies"
"github.com/grafana/grafana/pkg/setting"
"github.com/prometheus/client_golang/prometheus"
)
var _ ssosettings.Service = (*Service)(nil)
@ -31,6 +32,7 @@ type Service struct {
store ssosettings.Store
ac ac.AccessControl
secrets secrets.Service
metrics *metrics
fbStrategies []ssosettings.FallbackStrategy
reloadables map[string]ssosettings.Reloadable
@ -38,7 +40,7 @@ type Service struct {
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
routeRegister routing.RouteRegister, features featuremgmt.FeatureToggles,
secrets secrets.Service, usageStats usagestats.Service) *Service {
secrets secrets.Service, usageStats usagestats.Service, registerer prometheus.Registerer) *Service {
strategies := []ssosettings.FallbackStrategy{
strategies.NewOAuthStrategy(cfg),
// register other strategies here, for example SAML
@ -53,6 +55,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
ac: ac,
fbStrategies: strategies,
secrets: secrets,
metrics: newMetrics(registerer),
reloadables: make(map[string]ssosettings.Reloadable),
}
@ -193,13 +196,9 @@ func (s *Service) Upsert(ctx context.Context, settings *models.SSOSettings) erro
return err
}
go func() {
settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets)
err = social.Reload(context.Background(), *settings)
if err != nil {
s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
}
}()
settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets)
go s.reload(social, settings.Provider, *settings)
return nil
}
@ -212,7 +211,34 @@ func (s *Service) Delete(ctx context.Context, provider string) error {
if !s.isProviderConfigurable(provider) {
return ssosettings.ErrNotConfigurable
}
return s.store.Delete(ctx, provider)
social, ok := s.reloadables[provider]
if !ok {
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", provider)
}
err := s.store.Delete(ctx, provider)
if err != nil {
return err
}
currentSettings, err := s.GetForProvider(ctx, provider)
if err != nil {
s.logger.Error("failed to get current settings, skipping reload", "provider", provider, "error", err)
return nil
}
go s.reload(social, provider, *currentSettings)
return nil
}
func (s *Service) reload(reloadable ssosettings.Reloadable, provider string, currentSettings models.SSOSettings) {
err := reloadable.Reload(context.Background(), currentSettings)
if err != nil {
s.metrics.reloadFailures.WithLabelValues(provider).Inc()
s.logger.Error("failed to reload the provider", "provider", provider, "error", err)
}
}
func (s *Service) Reload(ctx context.Context, provider string) {
@ -327,6 +353,7 @@ func (s *Service) doReload(ctx context.Context) {
err = connector.Reload(ctx, *setting)
if err != nil {
s.metrics.reloadFailures.WithLabelValues(provider).Inc()
s.logger.Error("failed to reload SSO Settings", "provider", provider, "err", err)
continue
}

View File

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -1007,17 +1008,45 @@ func TestService_Delete(t *testing.T) {
t.Run("successfully delete SSO settings", func(t *testing.T) {
env := setupTestEnv(t)
var wg sync.WaitGroup
wg.Add(1)
provider := social.AzureADProviderName
env.store.ExpectedError = nil
reloadable := ssosettingstests.NewMockReloadable(t)
env.reloadables[provider] = reloadable
env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{
provider: {
"client_id": "client-id",
"client_secret": "client-secret",
"enabled": true,
},
}
reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool {
wg.Done()
return settings.Provider == provider &&
settings.ID == "" &&
maps.Equal(settings.Settings, map[string]any{
"client_id": "client-id",
"client_secret": "client-secret",
"enabled": true,
})
})).Return(nil).Once()
err := env.service.Delete(context.Background(), provider)
require.NoError(t, err)
// wait for the goroutine first to assert the Reload call
wg.Wait()
})
t.Run("SSO settings not found for the specified provider", func(t *testing.T) {
t.Run("return error if SSO setting was not found for the specified provider", func(t *testing.T) {
env := setupTestEnv(t)
provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t)
env.reloadables[provider] = reloadable
env.store.ExpectedError = ssosettings.ErrNotFound
err := env.service.Delete(context.Background(), provider)
@ -1036,7 +1065,7 @@ func TestService_Delete(t *testing.T) {
require.ErrorIs(t, err, ssosettings.ErrNotConfigurable)
})
t.Run("store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
t.Run("return error when store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
env := setupTestEnv(t)
provider := social.AzureADProviderName
@ -1046,6 +1075,22 @@ func TestService_Delete(t *testing.T) {
require.Error(t, err)
require.NotErrorIs(t, err, ssosettings.ErrNotFound)
})
t.Run("return successfully when the deletion was successful but reloading the settings fail", func(t *testing.T) {
env := setupTestEnv(t)
provider := social.AzureADProviderName
reloadable := ssosettingstests.NewMockReloadable(t)
env.reloadables[provider] = reloadable
env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) {
return nil, errors.New("failed to get sso settings")
}
err := env.service.Delete(context.Background(), provider)
require.NoError(t, err)
})
}
func TestService_DoReload(t *testing.T) {
@ -1223,6 +1268,7 @@ func setupTestEnv(t *testing.T) testEnv {
ac: accessControl,
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
reloadables: reloadables,
metrics: newMetrics(prometheus.NewRegistry()),
secrets: secrets,
}