AuthN: support priority for post auth and post login hooks (#62208)

* AuthN: store post auth hooks in a priority list and update registration
function to take a priority

* AuthN: store post login hooks in a priority list and update registration function to take a priority

* AuthN: Change priority for sync user
This commit is contained in:
Karl Persson 2023-01-27 11:40:12 +01:00 committed by GitHub
parent 14185ba819
commit 3447ad2602
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 25 deletions

View File

@ -54,12 +54,14 @@ type PostLoginHookFn func(ctx context.Context, identity *Identity, r *Request, e
type Service interface {
// Authenticate authenticates a request
Authenticate(ctx context.Context, r *Request) (*Identity, error)
// RegisterPostAuthHook registers a hook that is called after a successful authentication.
RegisterPostAuthHook(hook PostAuthHookFn)
// RegisterPostAuthHook registers a hook with a priority that is called after a successful authentication.
// A lower number means higher priority.
RegisterPostAuthHook(hook PostAuthHookFn, priority uint)
// Login authenticates a request and creates a session on successful authentication.
Login(ctx context.Context, client string, r *Request) (*Identity, error)
// RegisterPostLoginHook registers a hook that that is called after a login request.
RegisterPostLoginHook(hook PostLoginHookFn)
// A lower number means higher priority.
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
RedirectURL(ctx context.Context, client string, r *Request) (string, error)
}

View File

@ -60,7 +60,8 @@ func ProvideService(
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracer,
sessionService: sessionService,
postAuthHooks: []authn.PostAuthHookFn{},
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
}
s.RegisterClient(clients.ProvideRender(userService, renderService))
@ -69,7 +70,7 @@ func ProvideService(
if cfg.LoginCookieName != "" {
sessionClient := clients.ProvideSession(sessionService, userService, cfg.LoginCookieName, cfg.LoginMaxLifetime)
s.RegisterClient(sessionClient)
s.RegisterPostAuthHook(sessionClient.RefreshTokenHook)
s.RegisterPostAuthHook(sessionClient.RefreshTokenHook, 20)
}
if s.cfg.AnonymousEnabled {
@ -118,13 +119,13 @@ func ProvideService(
// FIXME (jguer): move to User package
userSyncService := sync.ProvideUserSync(userService, userProtectionService, authInfoService, quotaService)
orgUserSyncService := sync.ProvideOrgSync(userService, orgService, accessControlService)
s.RegisterPostAuthHook(userSyncService.SyncUser)
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgUser)
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen)
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen)
s.RegisterPostAuthHook(userSyncService.SyncUser, 10)
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgUser, 30)
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen, 40)
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen, 50)
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
s.RegisterPostAuthHook(sync.ProvideOauthTokenSync(oauthTokenService, sessionService).SyncOauthToken)
s.RegisterPostAuthHook(sync.ProvideOauthTokenSync(oauthTokenService, sessionService).SyncOauthToken, 60)
}
return s
@ -141,9 +142,9 @@ type Service struct {
sessionService auth.UserTokenService
// postAuthHooks are called after a successful authentication. They can modify the identity.
postAuthHooks []authn.PostAuthHookFn
postAuthHooks *queue[authn.PostAuthHookFn]
// postLoginHooks are called after a login request is performed, both for failing and successful requests.
postLoginHooks []authn.PostLoginHookFn
postLoginHooks *queue[authn.PostLoginHookFn]
}
func (s *Service) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
@ -182,8 +183,8 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
return nil, err
}
for _, hook := range s.postAuthHooks {
if err := hook(ctx, identity, r); err != nil {
for _, hook := range s.postAuthHooks.items {
if err := hook.v(ctx, identity, r); err != nil {
s.log.FromContext(ctx).Warn("post auth hook failed", "error", err, "id", identity)
return nil, err
}
@ -196,14 +197,14 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
return identity, nil
}
func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn) {
s.postAuthHooks = append(s.postAuthHooks, hook)
func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn, priority uint) {
s.postAuthHooks.insert(hook, priority)
}
func (s *Service) Login(ctx context.Context, client string, r *authn.Request) (identity *authn.Identity, err error) {
defer func() {
for _, hook := range s.postLoginHooks {
hook(ctx, identity, r, err)
for _, hook := range s.postLoginHooks.items {
hook.v(ctx, identity, r, err)
}
}()
@ -239,8 +240,8 @@ func (s *Service) Login(ctx context.Context, client string, r *authn.Request) (i
return identity, nil
}
func (s *Service) RegisterPostLoginHook(hook authn.PostLoginHookFn) {
s.postLoginHooks = append(s.postLoginHooks, hook)
func (s *Service) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority uint) {
s.postLoginHooks.insert(hook, priority)
}
func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (string, error) {

View File

@ -292,11 +292,13 @@ func setupTests(t *testing.T, opts ...func(svc *Service)) *Service {
t.Helper()
s := &Service{
log: log.NewNopLogger(),
cfg: setting.NewCfg(),
clientQueue: newQueue[authn.ContextAwareClient](),
clients: map[string]authn.Client{},
tracer: tracing.InitializeTracerForTest(),
log: log.NewNopLogger(),
cfg: setting.NewCfg(),
clients: map[string]authn.Client{},
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracing.InitializeTracerForTest(),
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
}
for _, o := range opts {