Auth: Convert SetDefaultOrgHook to PostLoginHook (#85649)

* Convert SetDefaultOrgHook to PostLoginHook
This commit is contained in:
Misi 2024-04-05 16:03:51 +02:00 committed by GitHub
parent 734d0111cb
commit 8796d2d307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 21 deletions

View File

@ -112,7 +112,7 @@ func (a *AnonDeviceService) untagDevice(ctx context.Context,
errD := a.anonStore.DeleteDevice(ctx, deviceID) errD := a.anonStore.DeleteDevice(ctx, deviceID)
if errD != nil { if errD != nil {
a.log.Debug("Failed to untag device", "error", err) a.log.Debug("Failed to untag device", "error", errD)
} }
} }

View File

@ -110,7 +110,7 @@ func ProvideRegistration(
} }
authnSvc.RegisterPostAuthHook(rbacSync.SyncPermissionsHook, 120) authnSvc.RegisterPostAuthHook(rbacSync.SyncPermissionsHook, 120)
authnSvc.RegisterPostAuthHook(orgSync.SetDefaultOrgHook, 130) authnSvc.RegisterPostLoginHook(orgSync.SetDefaultOrgHook, 140)
return Registration{} return Registration{}
} }

View File

@ -132,9 +132,9 @@ func (s *OrgSync) SyncOrgRolesHook(ctx context.Context, id *authn.Identity, _ *a
return nil return nil
} }
func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.Identity, r *authn.Request) error { func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.Identity, r *authn.Request, err error) {
if s.cfg.LoginDefaultOrgId < 1 || currentIdentity == nil { if s.cfg.LoginDefaultOrgId < 1 || currentIdentity == nil || err != nil {
return nil return
} }
ctxLogger := s.log.FromContext(ctx) ctxLogger := s.log.FromContext(ctx)
@ -142,33 +142,30 @@ func (s *OrgSync) SetDefaultOrgHook(ctx context.Context, currentIdentity *authn.
namespace, identifier := currentIdentity.GetNamespacedID() namespace, identifier := currentIdentity.GetNamespacedID()
if namespace != identity.NamespaceUser { if namespace != identity.NamespaceUser {
ctxLogger.Debug("Skipping default org sync, not a user", "namespace", namespace) ctxLogger.Debug("Skipping default org sync, not a user", "namespace", namespace)
return nil return
} }
userID, err := identity.IntIdentifier(namespace, identifier) userID, err := identity.IntIdentifier(namespace, identifier)
if err != nil { if err != nil {
ctxLogger.Debug("Skipping default org sync, invalid ID for identity", "id", currentIdentity.ID, "namespace", namespace, "err", err) ctxLogger.Debug("Skipping default org sync, invalid ID for identity", "id", currentIdentity.ID, "namespace", namespace, "err", err)
return nil return
} }
hasAssignedToOrg, err := s.validateUsingOrg(ctx, userID, s.cfg.LoginDefaultOrgId) hasAssignedToOrg, err := s.validateUsingOrg(ctx, userID, s.cfg.LoginDefaultOrgId)
if err != nil { if err != nil {
ctxLogger.Error("Skipping default org sync, failed to validate user's organizations", "id", currentIdentity.ID, "err", err) ctxLogger.Error("Skipping default org sync, failed to validate user's organizations", "id", currentIdentity.ID, "err", err)
return nil return
} }
if !hasAssignedToOrg { if !hasAssignedToOrg {
ctxLogger.Debug("Skipping default org sync, user is not assigned to org", "id", currentIdentity.ID, "org", s.cfg.LoginDefaultOrgId) ctxLogger.Debug("Skipping default org sync, user is not assigned to org", "id", currentIdentity.ID, "org", s.cfg.LoginDefaultOrgId)
return nil return
} }
cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: s.cfg.LoginDefaultOrgId} cmd := user.SetUsingOrgCommand{UserID: userID, OrgID: s.cfg.LoginDefaultOrgId}
if err := s.userService.SetUsingOrg(ctx, &cmd); err != nil { if svcErr := s.userService.SetUsingOrg(ctx, &cmd); svcErr != nil {
ctxLogger.Error("Failed to set default org", "id", currentIdentity.ID, "err", err) ctxLogger.Error("Failed to set default org", "id", currentIdentity.ID, "err", svcErr)
return err
} }
return nil
} }
func (s *OrgSync) validateUsingOrg(ctx context.Context, userID int64, orgID int64) (bool, error) { func (s *OrgSync) validateUsingOrg(ctx context.Context, userID int64, orgID int64) (bool, error) {

View File

@ -134,8 +134,7 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting int64 defaultOrgSetting int64
identity *authn.Identity identity *authn.Identity
setupMock func(*usertest.MockService, *orgtest.FakeOrgService) setupMock func(*usertest.MockService, *orgtest.FakeOrgService)
inputErr error
wantErr bool
}{ }{
{ {
name: "should set default org", name: "should set default org",
@ -157,6 +156,12 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
defaultOrgSetting: -1, defaultOrgSetting: -1,
identity: nil, identity: nil,
}, },
{
name: "should skip setting the default org when input err is not nil",
defaultOrgSetting: 2,
identity: &authn.Identity{ID: "user:1"},
inputErr: fmt.Errorf("error"),
},
{ {
name: "should skip setting the default org when identity is not a user", name: "should skip setting the default org when identity is not a user",
defaultOrgSetting: 2, defaultOrgSetting: 2,
@ -181,13 +186,12 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
}, },
}, },
{ {
name: "should return error when the user org update was unsuccessful", name: "should skip the hook when the user org update was unsuccessful",
defaultOrgSetting: 2, defaultOrgSetting: 2,
identity: &authn.Identity{ID: "user:1"}, identity: &authn.Identity{ID: "user:1"},
setupMock: func(userService *usertest.MockService, orgService *orgtest.FakeOrgService) { setupMock: func(userService *usertest.MockService, orgService *orgtest.FakeOrgService) {
userService.On("SetUsingOrg", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) userService.On("SetUsingOrg", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
}, },
wantErr: true,
}, },
} }
for _, tt := range testCases { for _, tt := range testCases {
@ -214,9 +218,9 @@ func TestOrgSync_SetDefaultOrgHook(t *testing.T) {
cfg: cfg, cfg: cfg,
} }
if err := s.SetDefaultOrgHook(context.Background(), tt.identity, nil); (err != nil) != tt.wantErr { s.SetDefaultOrgHook(context.Background(), tt.identity, nil, tt.inputErr)
t.Errorf("OrgSync.SetDefaultOrgHook() error = %v, wantErr %v", err, tt.wantErr)
} userService.AssertExpectations(t)
}) })
} }
} }