AuthN: add support for client specific hooks (#62863)

* AuthN: Add HookClient interface

* AuthN: Check if client implement authn.HookClient and call the hook if
it does

* AuthN: Convert refresh token hook into a client hook
This commit is contained in:
Karl Persson 2023-02-03 14:35:17 +01:00 committed by GitHub
parent 180a587f70
commit 6840cc11ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 9 deletions

View File

@ -78,6 +78,8 @@ type Client interface {
Authenticate(ctx context.Context, r *Request) (*Identity, error)
}
// ContextAwareClient is an optional interface that auth client can implement.
// Clients that implements this interface will be tried during request authentication
type ContextAwareClient interface {
Client
// Test should return true if client can be used to authenticate request
@ -86,6 +88,17 @@ type ContextAwareClient interface {
Priority() uint
}
// HookClient is an optional interface that auth clients can implement.
// Clients that implements this interface can specify an auth hook that will
// be called only for that client.
type HookClient interface {
Client
Hook(ctx context.Context, identity *Identity, r *Request) error
}
// RedirectClient is an optional interface that auth clients can implement.
// Clients that implements this interface can be used to generate redirect urls
// for authentication flows, e.g. oauth clients
type RedirectClient interface {
Client
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)

View File

@ -70,9 +70,7 @@ func ProvideService(
s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService))
if cfg.LoginCookieName != "" {
sessionClient := clients.ProvideSession(sessionService, userService, cfg.LoginCookieName, cfg.LoginMaxLifetime)
s.RegisterClient(sessionClient)
s.RegisterPostAuthHook(sessionClient.RefreshTokenHook, 20)
s.RegisterClient(clients.ProvideSession(sessionService, userService, cfg.LoginCookieName, cfg.LoginMaxLifetime))
}
if s.cfg.AnonymousEnabled {
@ -175,7 +173,6 @@ func (s *Service) Authenticate(ctx context.Context, r *authn.Request) (*authn.Id
if item.v.Test(ctx, r) {
identity, err := s.authenticate(ctx, item.v, r)
if err != nil {
s.log.Warn("failed to authenticate", "client", item.v.Name(), "err", err)
authErr = multierror.Append(authErr, err)
// try next
continue
@ -204,7 +201,7 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
for _, hook := range s.postAuthHooks.items {
if err := hook.v(ctx, identity, r); err != nil {
s.log.FromContext(ctx).Warn("post auth hook failed", "error", err, "id", identity)
s.log.FromContext(ctx).Warn("post auth hook failed", "error", err, "client", c.Name(), "id", identity.ID)
return nil, err
}
}
@ -213,6 +210,13 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
return nil, errDisabledIdentity.Errorf("identity is disabled")
}
if hc, ok := c.(authn.HookClient); ok {
if err := hc.Hook(ctx, identity, r); err != nil {
s.log.FromContext(ctx).Warn("post client auth hook failed", "error", err, "client", c.Name(), "id", identity.ID)
return nil, err
}
}
return identity, nil
}

View File

@ -9,6 +9,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
@ -107,7 +108,7 @@ func TestService_Authenticate(t *testing.T) {
}
}
func TestService_Authenticate_OrgID(t *testing.T) {
func TestService_OrgID(t *testing.T) {
type TestCase struct {
desc string
req *authn.Request
@ -168,6 +169,28 @@ func TestService_Authenticate_OrgID(t *testing.T) {
}
}
func TestService_HookClient(t *testing.T) {
hookCalled := false
s := setupTests(t, func(svc *Service) {
svc.RegisterClient(&authntest.MockClient{
AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
return &authn.Identity{}, nil
},
TestFunc: func(ctx context.Context, r *authn.Request) bool {
return true
},
HookFunc: func(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
hookCalled = true
return nil
},
})
})
_, _ = s.Authenticate(context.Background(), &authn.Request{})
require.True(t, hookCalled)
}
func TestService_Login(t *testing.T) {
type TestCase struct {
desc string

View File

@ -6,6 +6,7 @@ import (
"github.com/grafana/grafana/pkg/services/authn"
)
var _ authn.HookClient = new(MockClient)
var _ authn.ContextAwareClient = new(MockClient)
type MockClient struct {
@ -13,6 +14,7 @@ type MockClient struct {
AuthenticateFunc func(ctx context.Context, r *authn.Request) (*authn.Identity, error)
TestFunc func(ctx context.Context, r *authn.Request) bool
PriorityFunc func() uint
HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error
}
func (m MockClient) Name() string {
@ -43,6 +45,13 @@ func (m MockClient) Priority() uint {
return 0
}
func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
if m.HookFunc != nil {
return m.HookFunc(ctx, identity, r)
}
return nil
}
var _ authn.ProxyClient = new(MockProxyClient)
type MockProxyClient struct {

View File

@ -15,6 +15,7 @@ import (
"github.com/grafana/grafana/pkg/web"
)
var _ authn.HookClient = new(Session)
var _ authn.ContextAwareClient = new(Session)
func ProvideSession(sessionService auth.UserTokenService, userService user.Service,
@ -87,7 +88,7 @@ func (s *Session) Priority() uint {
return 60
}
func (s *Session) RefreshTokenHook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
if identity.SessionToken == nil {
return nil
}

View File

@ -141,7 +141,7 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) {
f.Status = statusCode
}
func TestSession_RefreshHook(t *testing.T) {
func TestSession_Hook(t *testing.T) {
s := ProvideSession(&authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
token.UnhashedToken = "new-token"
@ -168,7 +168,7 @@ func TestSession_RefreshHook(t *testing.T) {
Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter),
}
err := s.RefreshTokenHook(context.Background(), sampleID, resp)
err := s.Hook(context.Background(), sampleID, resp)
require.NoError(t, err)
resp.Resp.WriteHeader(201)