AuthN: Extract enable disabled users logic to its own hook (#63628)

This commit is contained in:
Karl Persson 2023-02-23 13:06:06 +01:00 committed by GitHub
parent 1406feb03c
commit 16b416b88b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 11 deletions

View File

@ -144,6 +144,7 @@ func ProvideService(
userSyncService := sync.ProvideUserSync(userService, userProtectionService, authInfoService, quotaService)
orgUserSyncService := sync.ProvideOrgSync(userService, orgService, accessControlService)
s.RegisterPostAuthHook(userSyncService.SyncUserHook, 10)
s.RegisterPostAuthHook(userSyncService.EnableDisabledUserHook, 20)
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgRolesHook, 30)
s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 40)

View File

@ -159,6 +159,23 @@ func (s *UserSync) SyncLastSeenHook(ctx context.Context, identity *authn.Identit
return nil
}
func (s *UserSync) EnableDisabledUserHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
if !identity.ClientParams.EnableDisabledUsers {
return nil
}
if !identity.IsDisabled {
return nil
}
namespace, id := identity.NamespacedID()
if namespace != authn.NamespaceUser {
return nil
}
return s.userService.Disable(ctx, &user.DisableUserCommand{UserID: id, IsDisabled: false})
}
func (s *UserSync) upsertAuthConnection(ctx context.Context, userID int64, identity *authn.Identity, createConnection bool) error {
if identity.AuthModule == "" {
return nil
@ -217,17 +234,6 @@ func (s *UserSync) updateUserAttributes(ctx context.Context, usr *user.User, id
}
}
// FIXME(kalleep): Should this be its own hook?
if usr.IsDisabled && id.ClientParams.EnableDisabledUsers {
usr.IsDisabled = false
if errDisableUser := s.userService.Disable(
ctx,
&user.DisableUserCommand{UserID: usr.ID, IsDisabled: false},
); errDisableUser != nil {
return errDisableUser
}
}
// Sync isGrafanaAdmin permission
if id.IsGrafanaAdmin != nil && *id.IsGrafanaAdmin != usr.IsAdmin {
usr.IsAdmin = *id.IsGrafanaAdmin

View File

@ -4,6 +4,7 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/services/authn"
@ -471,3 +472,66 @@ func TestUserSync_FetchSyncedUserHook(t *testing.T) {
})
}
}
func TestUserSync_EnableDisabledUserHook(t *testing.T) {
type testCase struct {
desc string
identity *authn.Identity
enableUser bool
}
tests := []testCase{
{
desc: "should skip if correct flag is not set",
identity: &authn.Identity{
ID: authn.NamespacedID(authn.NamespaceUser, 1),
IsDisabled: true,
ClientParams: authn.ClientParams{EnableDisabledUsers: false},
},
enableUser: false,
},
{
desc: "should skip if identity is not disabled",
identity: &authn.Identity{
ID: authn.NamespacedID(authn.NamespaceUser, 1),
IsDisabled: false,
ClientParams: authn.ClientParams{EnableDisabledUsers: true},
},
enableUser: false,
},
{
desc: "should skip if identity is not a user",
identity: &authn.Identity{
ID: authn.NamespacedID(authn.NamespaceAPIKey, 1),
IsDisabled: true,
ClientParams: authn.ClientParams{EnableDisabledUsers: true},
},
enableUser: false,
},
{
desc: "should enabled disabled user",
identity: &authn.Identity{
ID: authn.NamespacedID(authn.NamespaceUser, 1),
IsDisabled: true,
ClientParams: authn.ClientParams{EnableDisabledUsers: true},
},
enableUser: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
userSvc := usertest.NewUserServiceFake()
called := false
userSvc.DisableFn = func(ctx context.Context, cmd *user.DisableUserCommand) error {
called = true
return nil
}
s := UserSync{userService: userSvc}
err := s.EnableDisabledUserHook(context.Background(), tt.identity, nil)
require.NoError(t, err)
assert.Equal(t, tt.enableUser, called)
})
}
}

View File

@ -17,6 +17,7 @@ type FakeUserService struct {
GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error)
CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error)
DisableFn func(ctx context.Context, cmd *user.DisableUserCommand) error
counter int
}
@ -96,6 +97,9 @@ func (f *FakeUserService) Search(ctx context.Context, query *user.SearchUsersQue
}
func (f *FakeUserService) Disable(ctx context.Context, cmd *user.DisableUserCommand) error {
if f.DisableFn != nil {
return f.DisableFn(ctx, cmd)
}
return f.ExpectedError
}