mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AuthN: Support HA setups with External Service Account management (#78425)
* Lock when creating external service * Add local lock back * Improve function signature * Define lockName separately to make it more explicit * Update pkg/infra/serverlock/serverlock.go Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com> * Update pkg/infra/serverlock/serverlock.go Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com> --------- Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
This commit is contained in:
parent
61553e1693
commit
72759be6ec
@ -2,6 +2,8 @@ package serverlock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
@ -164,6 +166,74 @@ func (sl *ServerLockService) LockExecuteAndRelease(ctx context.Context, actionNa
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetryOpt is a callback function called after each failed lock acquisition try.
|
||||
// It gets the number of tries passed as an arg.
|
||||
type RetryOpt func(int) error
|
||||
|
||||
type LockTimeConfig struct {
|
||||
MaxInterval time.Duration // Duration after which we consider a lock to be dead and overtake it. Make sure this is big enough so that a server cannot acquire the lock while another server is processing.
|
||||
MinWait time.Duration // Minimum time to wait before retrying to acquire the lock.
|
||||
MaxWait time.Duration // Maximum time to wait before retrying to acquire the lock.
|
||||
}
|
||||
|
||||
// LockExecuteAndReleaseWithRetries mimics LockExecuteAndRelease but waits for the lock to be released if it is already taken.
|
||||
func (sl *ServerLockService) LockExecuteAndReleaseWithRetries(ctx context.Context, actionName string, timeConfig LockTimeConfig, fn func(ctx context.Context), retryOpts ...RetryOpt) error {
|
||||
start := time.Now()
|
||||
ctx, span := sl.tracer.Start(ctx, "ServerLockService.LockExecuteAndReleaseWithRetries")
|
||||
span.SetAttributes(attribute.String("serverlock.actionName", actionName))
|
||||
defer span.End()
|
||||
|
||||
ctxLogger := sl.log.FromContext(ctx)
|
||||
ctxLogger.Debug("Start LockExecuteAndReleaseWithRetries", "actionName", actionName)
|
||||
|
||||
lockChecks := 0
|
||||
|
||||
for {
|
||||
lockChecks++
|
||||
err := sl.acquireForRelease(ctx, actionName, timeConfig.MaxInterval)
|
||||
// could not get the lock
|
||||
if err != nil {
|
||||
var lockedErr *ServerLockExistsError
|
||||
if errors.As(err, &lockedErr) {
|
||||
// if the lock is already taken, wait and try again
|
||||
if lockChecks == 1 { // only warn on first lock check
|
||||
ctxLogger.Warn("another instance has the lock, waiting for it to be released", "actionName", actionName)
|
||||
}
|
||||
|
||||
for _, op := range retryOpts {
|
||||
if err := op(lockChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(lockWait(timeConfig.MinWait, timeConfig.MaxWait))
|
||||
continue
|
||||
}
|
||||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// lock was acquired and released successfully
|
||||
break
|
||||
}
|
||||
|
||||
sl.executeFunc(ctx, actionName, fn)
|
||||
|
||||
if err := sl.releaseLock(ctx, actionName); err != nil {
|
||||
span.RecordError(err)
|
||||
ctxLogger.Error("Failed to release the lock", "error", err)
|
||||
}
|
||||
|
||||
ctxLogger.Debug("LockExecuteAndReleaseWithRetries finished", "actionName", actionName, "duration", time.Since(start))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generate a random duration between minWait and maxWait to ensure instances unlock gradually
|
||||
func lockWait(minWait time.Duration, maxWait time.Duration) time.Duration {
|
||||
return time.Duration(rand.Int63n(int64(maxWait-minWait)) + int64(minWait))
|
||||
}
|
||||
|
||||
// acquireForRelease will check if the lock is already on the database, if it is, will check with maxInterval if it is
|
||||
// timeouted. Returns nil error if the lock was acquired correctly
|
||||
func (sl *ServerLockService) acquireForRelease(ctx context.Context, actionName string, maxInterval time.Duration) error {
|
||||
|
@ -2,6 +2,7 @@ package serverlock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -19,11 +20,11 @@ func TestIntegrationServerLock_LockAndExecute(t *testing.T) {
|
||||
atInterval := time.Hour
|
||||
ctx := context.Background()
|
||||
|
||||
//this time `fn` should be executed
|
||||
// this time `fn` should be executed
|
||||
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
|
||||
require.Equal(t, 1, counter)
|
||||
|
||||
//this should not execute `fn`
|
||||
// this should not execute `fn`
|
||||
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
|
||||
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
|
||||
require.Equal(t, 1, counter)
|
||||
@ -62,3 +63,65 @@ func TestIntegrationServerLock_LockExecuteAndRelease(t *testing.T) {
|
||||
|
||||
require.Equal(t, 4, counter)
|
||||
}
|
||||
|
||||
func TestIntegrationServerLock_LockExecuteAndReleaseWithRetries(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
sl := createTestableServerLock(t)
|
||||
|
||||
retries := 0
|
||||
expectedRetries := 10
|
||||
funcRuns := 0
|
||||
fn := func(context.Context) {
|
||||
funcRuns++
|
||||
}
|
||||
lockTimeConfig := LockTimeConfig{
|
||||
MaxInterval: time.Hour,
|
||||
MinWait: 0 * time.Millisecond,
|
||||
MaxWait: 1 * time.Millisecond,
|
||||
}
|
||||
ctx := context.Background()
|
||||
actionName := "test-operation"
|
||||
|
||||
// Acquire lock so that when `LockExecuteAndReleaseWithRetries` runs, it is forced
|
||||
// to retry
|
||||
err := sl.acquireForRelease(ctx, actionName, lockTimeConfig.MaxInterval)
|
||||
require.NoError(t, err)
|
||||
|
||||
wgRetries := sync.WaitGroup{}
|
||||
wgRetries.Add(expectedRetries)
|
||||
wgRelease := sync.WaitGroup{}
|
||||
wgRelease.Add(1)
|
||||
wgCompleted := sync.WaitGroup{}
|
||||
wgCompleted.Add(1)
|
||||
|
||||
onRetryFn := func(int) error {
|
||||
retries++
|
||||
wgRetries.Done()
|
||||
if retries == expectedRetries {
|
||||
// When we reach `expectedRetries`, wait for the lock to be released
|
||||
// to guarantee that next try will succeed
|
||||
wgRelease.Wait()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := sl.LockExecuteAndReleaseWithRetries(ctx, actionName, lockTimeConfig, fn, onRetryFn)
|
||||
require.NoError(t, err)
|
||||
wgCompleted.Done()
|
||||
}()
|
||||
|
||||
// Wait to release the lock until `LockExecuteAndReleaseWithRetries` has retried `expectedRetries` times.
|
||||
wgRetries.Wait()
|
||||
err = sl.releaseLock(ctx, actionName)
|
||||
require.NoError(t, err)
|
||||
wgRelease.Done()
|
||||
|
||||
// `LockExecuteAndReleaseWithRetries` has run completely.
|
||||
// Check that it had to retry because the lock was already taken.
|
||||
wgCompleted.Wait()
|
||||
require.Equal(t, expectedRetries, retries)
|
||||
require.Equal(t, 1, funcRuns)
|
||||
}
|
||||
|
@ -3,8 +3,10 @@ package registry
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/serverlock"
|
||||
"github.com/grafana/grafana/pkg/infra/slugify"
|
||||
"github.com/grafana/grafana/pkg/services/extsvcauth"
|
||||
"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver/oasimpl"
|
||||
@ -14,6 +16,12 @@ import (
|
||||
|
||||
var _ extsvcauth.ExternalServiceRegistry = &Registry{}
|
||||
|
||||
var lockTimeConfig = serverlock.LockTimeConfig{
|
||||
MaxInterval: 2 * time.Minute,
|
||||
MinWait: 1 * time.Second,
|
||||
MaxWait: 5 * time.Second,
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
features featuremgmt.FeatureToggles
|
||||
logger log.Logger
|
||||
@ -22,9 +30,10 @@ type Registry struct {
|
||||
|
||||
extSvcProviders map[string]extsvcauth.AuthProvider
|
||||
lock sync.Mutex
|
||||
serverLock *serverlock.ServerLockService
|
||||
}
|
||||
|
||||
func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, features featuremgmt.FeatureToggles) *Registry {
|
||||
func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, serverLock *serverlock.ServerLockService, features featuremgmt.FeatureToggles) *Registry {
|
||||
return &Registry{
|
||||
extSvcProviders: map[string]extsvcauth.AuthProvider{},
|
||||
features: features,
|
||||
@ -32,6 +41,7 @@ func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvc
|
||||
logger: log.New("extsvcauth.registry"),
|
||||
oauthReg: oauthServer,
|
||||
saReg: saSvc,
|
||||
serverLock: serverLock,
|
||||
}
|
||||
}
|
||||
|
||||
@ -104,7 +114,7 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
|
||||
r.logger.Debug("Routing External Service removal to the OAuth2Server", "service", name)
|
||||
return r.oauthReg.RemoveExternalService(ctx, name)
|
||||
default:
|
||||
return extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", provider)
|
||||
return extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,29 +122,42 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
|
||||
// it generates client_id, secrets and any additional provider specificities (ex: rsa keys). It also ensures that the
|
||||
// associated service account has the correct permissions.
|
||||
func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.ExternalServiceRegistration) (*extsvcauth.ExternalService, error) {
|
||||
// Record provider in case of removal
|
||||
r.lock.Lock()
|
||||
r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider
|
||||
r.lock.Unlock()
|
||||
var (
|
||||
errSave error
|
||||
extSvc *extsvcauth.ExternalService
|
||||
lockName = "ext-svc-save-" + cmd.Name
|
||||
)
|
||||
|
||||
switch cmd.AuthProvider {
|
||||
case extsvcauth.ServiceAccounts:
|
||||
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) {
|
||||
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts)
|
||||
return nil, nil
|
||||
err := r.serverLock.LockExecuteAndReleaseWithRetries(ctx, lockName, lockTimeConfig, func(ctx context.Context) {
|
||||
// Record provider in case of removal
|
||||
r.lock.Lock()
|
||||
r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider
|
||||
r.lock.Unlock()
|
||||
|
||||
switch cmd.AuthProvider {
|
||||
case extsvcauth.ServiceAccounts:
|
||||
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) {
|
||||
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts)
|
||||
return
|
||||
}
|
||||
r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name)
|
||||
extSvc, errSave = r.saReg.SaveExternalService(ctx, cmd)
|
||||
case extsvcauth.OAuth2Server:
|
||||
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) {
|
||||
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth)
|
||||
return
|
||||
}
|
||||
r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name)
|
||||
extSvc, errSave = r.oauthReg.SaveExternalService(ctx, cmd)
|
||||
default:
|
||||
errSave = extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", cmd.AuthProvider)
|
||||
}
|
||||
r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name)
|
||||
return r.saReg.SaveExternalService(ctx, cmd)
|
||||
case extsvcauth.OAuth2Server:
|
||||
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) {
|
||||
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth)
|
||||
return nil, nil
|
||||
}
|
||||
r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name)
|
||||
return r.oauthReg.SaveExternalService(ctx, cmd)
|
||||
default:
|
||||
return nil, extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", cmd.AuthProvider)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return extSvc, errSave
|
||||
}
|
||||
|
||||
// retrieveExtSvcProviders fetches external services from store and map their associated provider
|
||||
|
@ -2,7 +2,6 @@ package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
@ -29,7 +28,6 @@ func setupTestEnv(t *testing.T) *TestEnv {
|
||||
oauthReg: env.oauthReg,
|
||||
saReg: env.saReg,
|
||||
extSvcProviders: map[string]extsvcauth.AuthProvider{},
|
||||
lock: sync.Mutex{},
|
||||
}
|
||||
return &env
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user