AuthN: Support HA setups with External Service Account management (#78425)

* Lock when creating external service

* Add local lock back

* Improve function signature

* Define lockName separately to make it more explicit

* Update pkg/infra/serverlock/serverlock.go

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>

* Update pkg/infra/serverlock/serverlock.go

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>

---------

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
This commit is contained in:
Xavi Lacasa 2023-11-22 10:15:13 +01:00 committed by GitHub
parent 61553e1693
commit 72759be6ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 180 additions and 26 deletions

View File

@ -2,6 +2,8 @@ package serverlock
import (
"context"
"errors"
"math/rand"
"time"
"go.opentelemetry.io/otel/attribute"
@ -164,6 +166,74 @@ func (sl *ServerLockService) LockExecuteAndRelease(ctx context.Context, actionNa
return nil
}
// RetryOpt is a callback function called after each failed lock acquisition try.
// It gets the number of tries passed as an arg.
type RetryOpt func(int) error
type LockTimeConfig struct {
MaxInterval time.Duration // Duration after which we consider a lock to be dead and overtake it. Make sure this is big enough so that a server cannot acquire the lock while another server is processing.
MinWait time.Duration // Minimum time to wait before retrying to acquire the lock.
MaxWait time.Duration // Maximum time to wait before retrying to acquire the lock.
}
// LockExecuteAndReleaseWithRetries mimics LockExecuteAndRelease but waits for the lock to be released if it is already taken.
func (sl *ServerLockService) LockExecuteAndReleaseWithRetries(ctx context.Context, actionName string, timeConfig LockTimeConfig, fn func(ctx context.Context), retryOpts ...RetryOpt) error {
start := time.Now()
ctx, span := sl.tracer.Start(ctx, "ServerLockService.LockExecuteAndReleaseWithRetries")
span.SetAttributes(attribute.String("serverlock.actionName", actionName))
defer span.End()
ctxLogger := sl.log.FromContext(ctx)
ctxLogger.Debug("Start LockExecuteAndReleaseWithRetries", "actionName", actionName)
lockChecks := 0
for {
lockChecks++
err := sl.acquireForRelease(ctx, actionName, timeConfig.MaxInterval)
// could not get the lock
if err != nil {
var lockedErr *ServerLockExistsError
if errors.As(err, &lockedErr) {
// if the lock is already taken, wait and try again
if lockChecks == 1 { // only warn on first lock check
ctxLogger.Warn("another instance has the lock, waiting for it to be released", "actionName", actionName)
}
for _, op := range retryOpts {
if err := op(lockChecks); err != nil {
return err
}
}
time.Sleep(lockWait(timeConfig.MinWait, timeConfig.MaxWait))
continue
}
span.RecordError(err)
return err
}
// lock was acquired and released successfully
break
}
sl.executeFunc(ctx, actionName, fn)
if err := sl.releaseLock(ctx, actionName); err != nil {
span.RecordError(err)
ctxLogger.Error("Failed to release the lock", "error", err)
}
ctxLogger.Debug("LockExecuteAndReleaseWithRetries finished", "actionName", actionName, "duration", time.Since(start))
return nil
}
// generate a random duration between minWait and maxWait to ensure instances unlock gradually
func lockWait(minWait time.Duration, maxWait time.Duration) time.Duration {
return time.Duration(rand.Int63n(int64(maxWait-minWait)) + int64(minWait))
}
// acquireForRelease will check if the lock is already on the database, if it is, will check with maxInterval if it is
// timeouted. Returns nil error if the lock was acquired correctly
func (sl *ServerLockService) acquireForRelease(ctx context.Context, actionName string, maxInterval time.Duration) error {

View File

@ -2,6 +2,7 @@ package serverlock
import (
"context"
"sync"
"testing"
"time"
@ -19,11 +20,11 @@ func TestIntegrationServerLock_LockAndExecute(t *testing.T) {
atInterval := time.Hour
ctx := context.Background()
//this time `fn` should be executed
// this time `fn` should be executed
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
require.Equal(t, 1, counter)
//this should not execute `fn`
// this should not execute `fn`
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
require.Equal(t, 1, counter)
@ -62,3 +63,65 @@ func TestIntegrationServerLock_LockExecuteAndRelease(t *testing.T) {
require.Equal(t, 4, counter)
}
func TestIntegrationServerLock_LockExecuteAndReleaseWithRetries(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
sl := createTestableServerLock(t)
retries := 0
expectedRetries := 10
funcRuns := 0
fn := func(context.Context) {
funcRuns++
}
lockTimeConfig := LockTimeConfig{
MaxInterval: time.Hour,
MinWait: 0 * time.Millisecond,
MaxWait: 1 * time.Millisecond,
}
ctx := context.Background()
actionName := "test-operation"
// Acquire lock so that when `LockExecuteAndReleaseWithRetries` runs, it is forced
// to retry
err := sl.acquireForRelease(ctx, actionName, lockTimeConfig.MaxInterval)
require.NoError(t, err)
wgRetries := sync.WaitGroup{}
wgRetries.Add(expectedRetries)
wgRelease := sync.WaitGroup{}
wgRelease.Add(1)
wgCompleted := sync.WaitGroup{}
wgCompleted.Add(1)
onRetryFn := func(int) error {
retries++
wgRetries.Done()
if retries == expectedRetries {
// When we reach `expectedRetries`, wait for the lock to be released
// to guarantee that next try will succeed
wgRelease.Wait()
}
return nil
}
go func() {
err := sl.LockExecuteAndReleaseWithRetries(ctx, actionName, lockTimeConfig, fn, onRetryFn)
require.NoError(t, err)
wgCompleted.Done()
}()
// Wait to release the lock until `LockExecuteAndReleaseWithRetries` has retried `expectedRetries` times.
wgRetries.Wait()
err = sl.releaseLock(ctx, actionName)
require.NoError(t, err)
wgRelease.Done()
// `LockExecuteAndReleaseWithRetries` has run completely.
// Check that it had to retry because the lock was already taken.
wgCompleted.Wait()
require.Equal(t, expectedRetries, retries)
require.Equal(t, 1, funcRuns)
}

View File

@ -3,8 +3,10 @@ package registry
import (
"context"
"sync"
"time"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/slugify"
"github.com/grafana/grafana/pkg/services/extsvcauth"
"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver/oasimpl"
@ -14,6 +16,12 @@ import (
var _ extsvcauth.ExternalServiceRegistry = &Registry{}
var lockTimeConfig = serverlock.LockTimeConfig{
MaxInterval: 2 * time.Minute,
MinWait: 1 * time.Second,
MaxWait: 5 * time.Second,
}
type Registry struct {
features featuremgmt.FeatureToggles
logger log.Logger
@ -22,9 +30,10 @@ type Registry struct {
extSvcProviders map[string]extsvcauth.AuthProvider
lock sync.Mutex
serverLock *serverlock.ServerLockService
}
func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, features featuremgmt.FeatureToggles) *Registry {
func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, serverLock *serverlock.ServerLockService, features featuremgmt.FeatureToggles) *Registry {
return &Registry{
extSvcProviders: map[string]extsvcauth.AuthProvider{},
features: features,
@ -32,6 +41,7 @@ func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvc
logger: log.New("extsvcauth.registry"),
oauthReg: oauthServer,
saReg: saSvc,
serverLock: serverLock,
}
}
@ -104,7 +114,7 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
r.logger.Debug("Routing External Service removal to the OAuth2Server", "service", name)
return r.oauthReg.RemoveExternalService(ctx, name)
default:
return extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", provider)
return extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", provider)
}
}
@ -112,29 +122,42 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
// it generates client_id, secrets and any additional provider specificities (ex: rsa keys). It also ensures that the
// associated service account has the correct permissions.
func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.ExternalServiceRegistration) (*extsvcauth.ExternalService, error) {
// Record provider in case of removal
r.lock.Lock()
r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider
r.lock.Unlock()
var (
errSave error
extSvc *extsvcauth.ExternalService
lockName = "ext-svc-save-" + cmd.Name
)
switch cmd.AuthProvider {
case extsvcauth.ServiceAccounts:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts)
return nil, nil
err := r.serverLock.LockExecuteAndReleaseWithRetries(ctx, lockName, lockTimeConfig, func(ctx context.Context) {
// Record provider in case of removal
r.lock.Lock()
r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider
r.lock.Unlock()
switch cmd.AuthProvider {
case extsvcauth.ServiceAccounts:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts)
return
}
r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name)
extSvc, errSave = r.saReg.SaveExternalService(ctx, cmd)
case extsvcauth.OAuth2Server:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth)
return
}
r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name)
extSvc, errSave = r.oauthReg.SaveExternalService(ctx, cmd)
default:
errSave = extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", cmd.AuthProvider)
}
r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name)
return r.saReg.SaveExternalService(ctx, cmd)
case extsvcauth.OAuth2Server:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth)
return nil, nil
}
r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name)
return r.oauthReg.SaveExternalService(ctx, cmd)
default:
return nil, extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", cmd.AuthProvider)
})
if err != nil {
return nil, err
}
return extSvc, errSave
}
// retrieveExtSvcProviders fetches external services from store and map their associated provider

View File

@ -2,7 +2,6 @@ package registry
import (
"context"
"sync"
"testing"
"github.com/grafana/grafana/pkg/infra/log"
@ -29,7 +28,6 @@ func setupTestEnv(t *testing.T) *TestEnv {
oauthReg: env.oauthReg,
saReg: env.saReg,
extSvcProviders: map[string]extsvcauth.AuthProvider{},
lock: sync.Mutex{},
}
return &env
}