Authn: external identity sync (#73461)

* Authn: Add interface for external identity sync

This interface is implemented by authnimpl.Service and just triggers PostAuthHooks and skipping last seen update by default

* Authn: Add SyncIdentity to fake and add a new mock
This commit is contained in:
Karl Persson 2023-08-18 11:11:44 +02:00 committed by GitHub
parent 878e94ae25
commit 124e0efe1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 80 additions and 9 deletions

View File

@ -40,6 +40,7 @@ import (
"github.com/grafana/grafana/pkg/services/annotations/annotationsimpl"
"github.com/grafana/grafana/pkg/services/apikey/apikeyimpl"
"github.com/grafana/grafana/pkg/services/auth/jwt"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authnimpl"
"github.com/grafana/grafana/pkg/services/certgenerator"
"github.com/grafana/grafana/pkg/services/cleanup"
@ -350,6 +351,8 @@ var wireBasicSet = wire.NewSet(
tagimpl.ProvideService,
wire.Bind(new(tag.Service), new(*tagimpl.Service)),
authnimpl.ProvideService,
wire.Bind(new(authn.Service), new(*authnimpl.Service)),
wire.Bind(new(authn.IdentitySynchronizer), new(*authnimpl.Service)),
supportbundlesimpl.ProvideService,
grafanaapiserver.WireSet,
oasimpl.ProvideService,

View File

@ -84,6 +84,10 @@ type Service interface {
RegisterClient(c Client)
}
type IdentitySynchronizer interface {
SyncIdentity(ctx context.Context, identity *Identity) error
}
type Client interface {
// Name returns the name of a client
Name() string

View File

@ -50,6 +50,9 @@ var (
// make sure service implements authn.Service interface
var _ authn.Service = new(Service)
// make sure service implements authn.IdentitySynchronizer interface
var _ authn.IdentitySynchronizer = new(Service)
func ProvideService(
cfg *setting.Cfg, tracer tracing.Tracer,
orgService org.Service, sessionService auth.UserTokenService,
@ -65,7 +68,7 @@ func ProvideService(
socialService social.Service, cache *remotecache.RemoteCache,
ldapService service.LDAP, registerer prometheus.Registerer,
signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server,
) authn.Service {
) *Service {
s := &Service{
log: log.New("authn.service"),
cfg: cfg,
@ -228,11 +231,9 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
return nil, err
}
for _, hook := range s.postAuthHooks.items {
if err := hook.v(ctx, identity, r); err != nil {
s.log.FromContext(ctx).Warn("Failed to run post auth hook", "client", c.Name(), "id", identity.ID, "error", err)
return nil, err
}
if err := s.runPostAuthHooks(ctx, identity, r); err != nil {
s.log.FromContext(ctx).Warn("Failed to run post auth hook", "client", c.Name(), "id", identity.ID, "error", err)
return nil, err
}
if identity.IsDisabled {
@ -249,6 +250,15 @@ func (s *Service) authenticate(ctx context.Context, c authn.Client, r *authn.Req
return identity, nil
}
func (s *Service) runPostAuthHooks(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
for _, hook := range s.postAuthHooks.items {
if err := hook.v(ctx, identity, r); err != nil {
return err
}
}
return nil
}
func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn, priority uint) {
s.postAuthHooks.insert(hook, priority)
}
@ -332,6 +342,13 @@ func (s *Service) RegisterClient(c authn.Client) {
}
}
func (s *Service) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
r := &authn.Request{OrgID: identity.OrgID}
// hack to not update last seen on external syncs
r.SetMeta(authn.MetaKeyIsLogin, "true")
return s.runPostAuthHooks(ctx, identity, r)
}
func orgIDFromRequest(r *authn.Request) int64 {
if r.HTTPRequest == nil {
return 0

View File

@ -7,6 +7,7 @@ import (
)
var _ authn.Service = new(FakeService)
var _ authn.IdentitySynchronizer = new(FakeService)
type FakeService struct {
ExpectedErr error
@ -67,6 +68,10 @@ func (f *FakeService) RedirectURL(ctx context.Context, client string, r *authn.R
func (f *FakeService) RegisterClient(c authn.Client) {}
func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
return f.ExpectedErr
}
var _ authn.ContextAwareClient = new(FakeClient)
type FakeClient struct {

View File

@ -6,6 +6,44 @@ import (
"github.com/grafana/grafana/pkg/services/authn"
)
var _ authn.Service = new(MockService)
var _ authn.IdentitySynchronizer = new(MockService)
type MockService struct {
SyncIdentityFunc func(ctx context.Context, identity *authn.Identity) error
}
func (m *MockService) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
panic("unimplemented")
}
func (m *MockService) Login(ctx context.Context, client string, r *authn.Request) (*authn.Identity, error) {
panic("unimplemented")
}
func (m *MockService) RedirectURL(ctx context.Context, client string, r *authn.Request) (*authn.Redirect, error) {
panic("unimplemented")
}
func (m *MockService) RegisterClient(c authn.Client) {
panic("unimplemented")
}
func (m *MockService) RegisterPostAuthHook(hook authn.PostAuthHookFn, priority uint) {
panic("unimplemented")
}
func (m *MockService) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority uint) {
panic("unimplemented")
}
func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
if m.SyncIdentityFunc != nil {
return m.SyncIdentityFunc(ctx, identity)
}
return nil
}
var _ authn.HookClient = new(MockClient)
var _ authn.ContextAwareClient = new(MockClient)

View File

@ -16,9 +16,10 @@ type FakeUserService struct {
ExpectedUserProfileDTOs []*user.UserProfileDTO
ExpectedUsageStats map[string]interface{}
GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error)
CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error)
DisableFn func(ctx context.Context, cmd *user.DisableUserCommand) error
GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error)
CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error)
DisableFn func(ctx context.Context, cmd *user.DisableUserCommand) error
BatchDisableUsersFn func(ctx context.Context, cmd *user.BatchDisableUsersCommand) error
counter int
}
@ -105,6 +106,9 @@ func (f *FakeUserService) Disable(ctx context.Context, cmd *user.DisableUserComm
}
func (f *FakeUserService) BatchDisableUsers(ctx context.Context, cmd *user.BatchDisableUsersCommand) error {
if f.BatchDisableUsersFn != nil {
return f.BatchDisableUsersFn(ctx, cmd)
}
return f.ExpectedError
}