diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index bf2efd2c6f7..c9d23cd6f8f 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -78,6 +78,8 @@ type Client interface { Authenticate(ctx context.Context, r *Request) (*Identity, error) } +// ContextAwareClient is an optional interface that auth client can implement. +// Clients that implements this interface will be tried during request authentication type ContextAwareClient interface { Client // Test should return true if client can be used to authenticate request @@ -86,6 +88,17 @@ type ContextAwareClient interface { Priority() uint } +// HookClient is an optional interface that auth clients can implement. +// Clients that implements this interface can specify an auth hook that will +// be called only for that client. +type HookClient interface { + Client + Hook(ctx context.Context, identity *Identity, r *Request) error +} + +// RedirectClient is an optional interface that auth clients can implement. +// Clients that implements this interface can be used to generate redirect urls +// for authentication flows, e.g. oauth clients type RedirectClient interface { Client RedirectURL(ctx context.Context, r *Request) (*Redirect, error) diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index 0b41e3bff43..b69f9dcf073 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -70,9 +70,7 @@ func ProvideService( s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService)) if cfg.LoginCookieName != "" { - sessionClient := clients.ProvideSession(sessionService, userService, cfg.LoginCookieName, cfg.LoginMaxLifetime) - s.RegisterClient(sessionClient) - s.RegisterPostAuthHook(sessionClient.RefreshTokenHook, 20) + s.RegisterClient(clients.ProvideSession(sessionService, userService, cfg.LoginCookieName, cfg.LoginMaxLifetime)) } if s.cfg.AnonymousEnabled { @@ -175,7 +173,6 @@ func (s *Service) Authenticate(ctx context.Context, r *authn.Request) (*authn.Id if item.v.Test(ctx, r) { identity, err := s.authenticate(ctx, item.v, r) if err != nil { - s.log.Warn("failed to authenticate", "client", item.v.Name(), "err", err) authErr = multierror.Append(authErr, err) // try next continue @@ -204,7 +201,7 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req for _, hook := range s.postAuthHooks.items { if err := hook.v(ctx, identity, r); err != nil { - s.log.FromContext(ctx).Warn("post auth hook failed", "error", err, "id", identity) + s.log.FromContext(ctx).Warn("post auth hook failed", "error", err, "client", c.Name(), "id", identity.ID) return nil, err } } @@ -213,6 +210,13 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req return nil, errDisabledIdentity.Errorf("identity is disabled") } + if hc, ok := c.(authn.HookClient); ok { + if err := hc.Hook(ctx, identity, r); err != nil { + s.log.FromContext(ctx).Warn("post client auth hook failed", "error", err, "client", c.Name(), "id", identity.ID) + return nil, err + } + } + return identity, nil } diff --git a/pkg/services/authn/authnimpl/service_test.go b/pkg/services/authn/authnimpl/service_test.go index 5e6f8631d51..fe724d8c282 100644 --- a/pkg/services/authn/authnimpl/service_test.go +++ b/pkg/services/authn/authnimpl/service_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" @@ -107,7 +108,7 @@ func TestService_Authenticate(t *testing.T) { } } -func TestService_Authenticate_OrgID(t *testing.T) { +func TestService_OrgID(t *testing.T) { type TestCase struct { desc string req *authn.Request @@ -168,6 +169,28 @@ func TestService_Authenticate_OrgID(t *testing.T) { } } +func TestService_HookClient(t *testing.T) { + hookCalled := false + + s := setupTests(t, func(svc *Service) { + svc.RegisterClient(&authntest.MockClient{ + AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) { + return &authn.Identity{}, nil + }, + TestFunc: func(ctx context.Context, r *authn.Request) bool { + return true + }, + HookFunc: func(ctx context.Context, identity *authn.Identity, r *authn.Request) error { + hookCalled = true + return nil + }, + }) + }) + + _, _ = s.Authenticate(context.Background(), &authn.Request{}) + require.True(t, hookCalled) +} + func TestService_Login(t *testing.T) { type TestCase struct { desc string diff --git a/pkg/services/authn/authntest/mock.go b/pkg/services/authn/authntest/mock.go index 98f5ae5c6c2..ea6edfdd92f 100644 --- a/pkg/services/authn/authntest/mock.go +++ b/pkg/services/authn/authntest/mock.go @@ -6,6 +6,7 @@ import ( "github.com/grafana/grafana/pkg/services/authn" ) +var _ authn.HookClient = new(MockClient) var _ authn.ContextAwareClient = new(MockClient) type MockClient struct { @@ -13,6 +14,7 @@ type MockClient struct { AuthenticateFunc func(ctx context.Context, r *authn.Request) (*authn.Identity, error) TestFunc func(ctx context.Context, r *authn.Request) bool PriorityFunc func() uint + HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error } func (m MockClient) Name() string { @@ -43,6 +45,13 @@ func (m MockClient) Priority() uint { return 0 } +func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { + if m.HookFunc != nil { + return m.HookFunc(ctx, identity, r) + } + return nil +} + var _ authn.ProxyClient = new(MockProxyClient) type MockProxyClient struct { diff --git a/pkg/services/authn/clients/session.go b/pkg/services/authn/clients/session.go index 2af8d35bc72..d57daa71279 100644 --- a/pkg/services/authn/clients/session.go +++ b/pkg/services/authn/clients/session.go @@ -15,6 +15,7 @@ import ( "github.com/grafana/grafana/pkg/web" ) +var _ authn.HookClient = new(Session) var _ authn.ContextAwareClient = new(Session) func ProvideSession(sessionService auth.UserTokenService, userService user.Service, @@ -87,7 +88,7 @@ func (s *Session) Priority() uint { return 60 } -func (s *Session) RefreshTokenHook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { +func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { if identity.SessionToken == nil { return nil } diff --git a/pkg/services/authn/clients/session_test.go b/pkg/services/authn/clients/session_test.go index 4809c03fb54..b8b363987d4 100644 --- a/pkg/services/authn/clients/session_test.go +++ b/pkg/services/authn/clients/session_test.go @@ -141,7 +141,7 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) { f.Status = statusCode } -func TestSession_RefreshHook(t *testing.T) { +func TestSession_Hook(t *testing.T) { s := ProvideSession(&authtest.FakeUserAuthTokenService{ TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { token.UnhashedToken = "new-token" @@ -168,7 +168,7 @@ func TestSession_RefreshHook(t *testing.T) { Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter), } - err := s.RefreshTokenHook(context.Background(), sampleID, resp) + err := s.Hook(context.Background(), sampleID, resp) require.NoError(t, err) resp.Resp.WriteHeader(201)