Auth: Convert SetDefaultOrgHook to PostLoginHook (#85649)

* Convert SetDefaultOrgHook to PostLoginHook
This commit is contained in:
Misi 2024-04-05 16:03:51 +02:00 committed by GitHub
parent 734d0111cb
commit 8796d2d307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 21 deletions

View File

@ -112,7 +112,7 @@ func (a *AnonDeviceService) untagDevice(ctx context.Context,
errD := a.anonStore.DeleteDevice(ctx, deviceID)
if errD != nil {
a.log.Debug("Failed to untag device", "error", err)
a.log.Debug("Failed to untag device", "error", errD)
}
}

View File

@ -110,7 +110,7 @@ func ProvideRegistration(
}
authnSvc.RegisterPostAuthHook(rbacSync.SyncPermissionsHook, 120)
authnSvc.RegisterPostAuthHook(orgSync.SetDefaultOrgHook, 130)
authnSvc.RegisterPostLoginHook(orgSync.SetDefaultOrgHook, 140)
return Registration{}
}

View File

@ -132,9 +132,9 @@ func (s *OrgSync) SyncOrgRolesHook(ctx context.Context, id *authn.Identity, _ *a
return nil
}
func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.Identity, r *authn.Request) error {
if s.cfg.LoginDefaultOrgId < 1 || currentIdentity == nil {
return nil
func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.Identity, r *authn.Request, err error) {
if s.cfg.LoginDefaultOrgId < 1 || currentIdentity == nil || err != nil {
return
}
ctxLogger := s.log.FromContext(ctx)
@ -142,33 +142,30 @@ func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.
namespace, identifier := currentIdentity.GetNamespacedID()
if namespace != identity.NamespaceUser {
ctxLogger.Debug("Skipping default org sync, not a user", "namespace", namespace)
return nil
return
}
userID, err := identity.IntIdentifier(namespace, identifier)
if err != nil {
ctxLogger.Debug("Skipping default org sync, invalid ID for identity", "id", currentIdentity.ID, "namespace", namespace, "err", err)
return nil
return
}
hasAssignedToOrg, err := s.validateUsingOrg(ctx, userID, s.cfg.LoginDefaultOrgId)
if err != nil {
ctxLogger.Error("Skipping default org sync, failed to validate user's organizations", "id", currentIdentity.ID, "err", err)
return nil
return
}
if !hasAssignedToOrg {
ctxLogger.Debug("Skipping default org sync, user is not assigned to org", "id", currentIdentity.ID, "org", s.cfg.LoginDefaultOrgId)
return nil
return
}
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: s.cfg.LoginDefaultOrgId}
if err := s.userService.SetUsingOrg(ctx, &cmd); err != nil {
ctxLogger.Error("Failed to set default org", "id", currentIdentity.ID, "err", err)
return err
if svcErr := s.userService.SetUsingOrg(ctx, &cmd); svcErr != nil {
ctxLogger.Error("Failed to set default org", "id", currentIdentity.ID, "err", svcErr)
}
return nil
}
func (s *OrgSync) validateUsingOrg(ctx context.Context, userID int64, orgID int64) (bool, error) {

View File

@ -134,8 +134,7 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting int64
identity *authn.Identity
setupMock func(*usertest.MockService, *orgtest.FakeOrgService)
wantErr bool
inputErr error
}{
{
name: "should set default org",
@ -157,6 +156,12 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting: -1,
identity: nil,
},
{
name: "should skip setting the default org when input err is not nil",
defaultOrgSetting: 2,
identity: &authn.Identity{ID: "user:1"},
inputErr: fmt.Errorf("error"),
},
{
name: "should skip setting the default org when identity is not a user",
defaultOrgSetting: 2,
@ -181,13 +186,12 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
},
},
{
name: "should return error when the user org update was unsuccessful",
name: "should skip the hook when the user org update was unsuccessful",
defaultOrgSetting: 2,
identity: &authn.Identity{ID: "user:1"},
setupMock: func(userService *usertest.MockService, orgService *orgtest.FakeOrgService) {
userService.On("SetUsingOrg", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
},
wantErr: true,
},
}
for _, tt := range testCases {
@ -214,9 +218,9 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
cfg: cfg,
}
if err := s.SetDefaultOrgHook(context.Background(), tt.identity, nil); (err != nil) != tt.wantErr {
t.Errorf("OrgSync.SetDefaultOrgHook() error = %v, wantErr %v", err, tt.wantErr)
}
s.SetDefaultOrgHook(context.Background(), tt.identity, nil, tt.inputErr)
userService.AssertExpectations(t)
})
}
}