mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: Remove result fields from login (#65136)
* remove result fields from login * fix tests * fix tests * another shadowing
This commit is contained in:
parent
3b37135b5b
commit
a38f230d37
@ -294,7 +294,7 @@ func (hs *HTTPServer) AdminDisableUser(c *contextmodel.ReqContext) response.Resp
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
|
||||
if _, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not disable external user", nil)
|
||||
}
|
||||
|
||||
@ -337,7 +337,7 @@ func (hs *HTTPServer) AdminEnableUser(c *contextmodel.ReqContext) response.Respo
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
|
||||
if _, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not enable external user", nil)
|
||||
}
|
||||
|
||||
|
@ -318,8 +318,8 @@ func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
|
||||
// If SAML is enabled and this is a SAML user use saml logout
|
||||
if hs.samlSingleLogoutEnabled() {
|
||||
getAuthQuery := loginservice.GetAuthInfoQuery{UserId: c.UserID}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
if getAuthQuery.Result.AuthModule == loginservice.SAMLAuthModule {
|
||||
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
if authInfo.AuthModule == loginservice.SAMLAuthModule {
|
||||
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
|
||||
return
|
||||
}
|
||||
|
@ -351,18 +351,19 @@ func (hs *HTTPServer) SyncUser(
|
||||
},
|
||||
}
|
||||
|
||||
if err := hs.Login.UpsertUser(ctx.Req.Context(), cmd); err != nil {
|
||||
upsertedUser, err := hs.Login.UpsertUser(ctx.Req.Context(), cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Do not expose disabled status,
|
||||
// just show incorrect user credentials error (see #17947)
|
||||
if cmd.Result.IsDisabled {
|
||||
oauthLogger.Warn("User is disabled", "user", cmd.Result.Login)
|
||||
if upsertedUser.IsDisabled {
|
||||
oauthLogger.Warn("User is disabled", "user", upsertedUser.Login)
|
||||
return nil, login.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return cmd.Result, nil
|
||||
return upsertedUser, nil
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) hashStatecode(code, seed string) string {
|
||||
|
@ -38,8 +38,8 @@ func (hs *HTTPServer) SendResetPasswordEmail(c *contextmodel.ReqContext) respons
|
||||
}
|
||||
|
||||
getAuthQuery := login.GetAuthInfoQuery{UserId: usr.ID}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authModule := getAuthQuery.Result.AuthModule
|
||||
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authModule := authInfo.AuthModule
|
||||
if authModule == login.LDAPAuthModule || authModule == login.AuthProxyAuthModule {
|
||||
return response.Error(401, "Not allowed to reset password for LDAP or Auth Proxy user", nil)
|
||||
}
|
||||
|
@ -63,8 +63,8 @@ func (hs *HTTPServer) getUserUserProfile(c *contextmodel.ReqContext, userID int6
|
||||
|
||||
getAuthQuery := login.GetAuthInfoQuery{UserId: userID}
|
||||
userProfile.AuthLabels = []string{}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authLabel := login.GetAuthProviderLabel(getAuthQuery.Result.AuthModule)
|
||||
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authLabel := login.GetAuthProviderLabel(authInfo.AuthModule)
|
||||
userProfile.AuthLabels = append(userProfile.AuthLabels, authLabel)
|
||||
userProfile.IsExternal = true
|
||||
userProfile.IsExternallySynced = login.IsExternallySynced(hs.Cfg, getAuthQuery.Result.AuthModule)
|
||||
@ -225,7 +225,7 @@ func (hs *HTTPServer) handleUpdateUser(ctx context.Context, cmd user.UpdateUserC
|
||||
func (hs *HTTPServer) isExternalUser(ctx context.Context, userID int64) (bool, error) {
|
||||
getAuthQuery := login.GetAuthInfoQuery{UserId: userID}
|
||||
var err error
|
||||
if err = hs.authInfoService.GetAuthInfo(ctx, &getAuthQuery); err == nil {
|
||||
if _, err = hs.authInfoService.GetAuthInfo(ctx, &getAuthQuery); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@ -434,8 +434,8 @@ func (hs *HTTPServer) ChangeUserPassword(c *contextmodel.ReqContext) response.Re
|
||||
}
|
||||
|
||||
getAuthQuery := login.GetAuthInfoQuery{UserId: usr.ID}
|
||||
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authModule := getAuthQuery.Result.AuthModule
|
||||
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
|
||||
authModule := authInfo.AuthModule
|
||||
if authModule == login.LDAPAuthModule || authModule == login.AuthProxyAuthModule {
|
||||
return response.Error(400, "Not allowed to reset password for LDAP or Auth Proxy user", nil)
|
||||
}
|
||||
|
@ -59,10 +59,6 @@ var loginUsingLDAP = func(ctx context.Context, query *login.LoginUserQuery,
|
||||
UserID: nil,
|
||||
},
|
||||
}
|
||||
if err = loginService.UpsertUser(ctx, upsert); err != nil {
|
||||
return true, err
|
||||
}
|
||||
query.User = upsert.Result
|
||||
|
||||
return true, nil
|
||||
query.User, err = loginService.UpsertUser(ctx, upsert)
|
||||
return true, err
|
||||
}
|
||||
|
@ -280,16 +280,16 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
|
||||
// Check auth info fist
|
||||
if identity.AuthID != "" && identity.AuthModule != "" {
|
||||
query := &login.GetAuthInfoQuery{AuthId: identity.AuthID, AuthModule: identity.AuthModule}
|
||||
errGetAuthInfo := s.authInfoService.GetAuthInfo(ctx, query)
|
||||
authInfo, errGetAuthInfo := s.authInfoService.GetAuthInfo(ctx, query)
|
||||
|
||||
if errGetAuthInfo != nil && !errors.Is(errGetAuthInfo, user.ErrUserNotFound) {
|
||||
return nil, nil, errGetAuthInfo
|
||||
}
|
||||
|
||||
if !errors.Is(errGetAuthInfo, user.ErrUserNotFound) {
|
||||
usr, errGetByID := s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: query.Result.UserId})
|
||||
usr, errGetByID := s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: authInfo.UserId})
|
||||
if errGetByID == nil {
|
||||
return usr, query.Result, nil
|
||||
return usr, authInfo, nil
|
||||
}
|
||||
|
||||
if !errors.Is(errGetByID, user.ErrUserNotFound) {
|
||||
@ -298,7 +298,7 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
|
||||
|
||||
// if the user connected to user auth does not exist try to clean it up
|
||||
if errors.Is(errGetByID, user.ErrUserNotFound) {
|
||||
if err := s.authInfoService.DeleteUserAuthInfo(ctx, query.Result.UserId); err != nil {
|
||||
if err := s.authInfoService.DeleteUserAuthInfo(ctx, authInfo.UserId); err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to clean up user auth", "error", err, "auth_module", identity.AuthModule, "auth_id", identity.AuthID)
|
||||
}
|
||||
}
|
||||
@ -316,11 +316,10 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
|
||||
// so we need to find the user first then check for the userAuth connection by module and userID
|
||||
if identity.AuthModule == login.GenericOAuthModule {
|
||||
query := &login.GetAuthInfoQuery{AuthModule: identity.AuthModule, UserId: usr.ID}
|
||||
err := s.authInfoService.GetAuthInfo(ctx, query)
|
||||
userAuth, err = s.authInfoService.GetAuthInfo(ctx, query)
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
userAuth = query.Result
|
||||
}
|
||||
|
||||
return usr, userAuth, nil
|
||||
|
@ -127,7 +127,7 @@ func (h *ContextHandler) initContextWithJWT(ctx *contextmodel.ReqContext, orgId
|
||||
Email: &query.Email,
|
||||
},
|
||||
}
|
||||
if err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil {
|
||||
if _, err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil {
|
||||
ctx.Logger.Error("Failed to upsert JWT user", "error", err)
|
||||
return false
|
||||
}
|
||||
|
@ -241,11 +241,12 @@ func (auth *AuthProxy) LoginViaLDAP(reqCtx *contextmodel.ReqContext) (int64, err
|
||||
UserID: nil,
|
||||
},
|
||||
}
|
||||
if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
|
||||
u, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return upsert.Result.ID, nil
|
||||
return u.ID, nil
|
||||
}
|
||||
|
||||
// loginViaHeader logs in user from the header only
|
||||
@ -304,12 +305,12 @@ func (auth *AuthProxy) loginViaHeader(reqCtx *contextmodel.ReqContext) (int64, e
|
||||
},
|
||||
}
|
||||
|
||||
err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
|
||||
result, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return upsert.Result.ID, nil
|
||||
return result.ID, nil
|
||||
}
|
||||
|
||||
// getDecodedHeader gets decoded value of a header with given headerName
|
||||
|
@ -193,7 +193,7 @@ func (s *Service) PostSyncUserWithLDAP(c *contextmodel.ReqContext) response.Resp
|
||||
}
|
||||
|
||||
authModuleQuery := &login.GetAuthInfoQuery{UserId: usr.ID, AuthModule: login.LDAPAuthModule}
|
||||
if err := s.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP
|
||||
if _, err := s.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return response.Error(404, user.ErrUserNotFound.Error(), nil)
|
||||
}
|
||||
@ -239,7 +239,7 @@ func (s *Service) PostSyncUserWithLDAP(c *contextmodel.ReqContext) response.Resp
|
||||
},
|
||||
}
|
||||
|
||||
err = s.loginService.UpsertUser(c.Req.Context(), upsertCmd)
|
||||
_, err = s.loginService.UpsertUser(c.Req.Context(), upsertCmd)
|
||||
if err != nil {
|
||||
return response.Error(http.StatusInternalServerError, "Failed to update the user", err)
|
||||
}
|
||||
|
@ -9,17 +9,17 @@ import (
|
||||
|
||||
type AuthInfoService interface {
|
||||
LookupAndUpdate(ctx context.Context, query *GetUserByAuthInfoQuery) (*user.User, error)
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) error
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
|
||||
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
|
||||
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) error
|
||||
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) (*ExternalUserInfo, error)
|
||||
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
|
||||
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error
|
||||
DeleteUserAuthInfo(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) error
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) error
|
||||
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) (*ExternalUserInfo, error)
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
|
||||
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
|
||||
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
|
||||
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error
|
||||
|
@ -34,35 +34,36 @@ func ProvideAuthInfoStore(sqlStore db.DB, secretsService secrets.Service, userSe
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
|
||||
func (s *AuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
|
||||
userQuery := user.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
||||
usr, err := s.userService.GetByLogin(ctx, &userQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.ID}
|
||||
if err := s.GetAuthInfo(ctx, authInfoQuery); err != nil {
|
||||
return err
|
||||
authInfo, err := s.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query.Result = &login.ExternalUserInfo{
|
||||
result := &login.ExternalUserInfo{
|
||||
UserId: usr.ID,
|
||||
Login: usr.Login,
|
||||
Email: usr.Email,
|
||||
Name: usr.Name,
|
||||
IsDisabled: usr.IsDisabled,
|
||||
AuthModule: authInfoQuery.Result.AuthModule,
|
||||
AuthId: authInfoQuery.Result.AuthId,
|
||||
AuthModule: authInfo.AuthModule,
|
||||
AuthId: authInfo.AuthId,
|
||||
}
|
||||
return nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAuthInfo returns the auth info for a user
|
||||
// It will return the latest auth info for a user
|
||||
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
|
||||
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
if query.UserId == 0 && query.AuthId == "" {
|
||||
return user.ErrUserNotFound
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
userAuth := &login.UserAuth{
|
||||
@ -79,36 +80,35 @@ func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInf
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !has {
|
||||
return user.ErrUserNotFound
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
secretAccessToken, err := s.decodeAndDecrypt(userAuth.OAuthAccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
secretRefreshToken, err := s.decodeAndDecrypt(userAuth.OAuthRefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
secretTokenType, err := s.decodeAndDecrypt(userAuth.OAuthTokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
secretIdToken, err := s.decodeAndDecrypt(userAuth.OAuthIdToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
userAuth.OAuthAccessToken = secretAccessToken
|
||||
userAuth.OAuthRefreshToken = secretRefreshToken
|
||||
userAuth.OAuthTokenType = secretTokenType
|
||||
userAuth.OAuthIdToken = secretIdToken
|
||||
|
||||
query.Result = userAuth
|
||||
return nil
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
|
||||
|
@ -38,7 +38,7 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
authQuery.AuthId = query.AuthId
|
||||
|
||||
err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
if !errors.Is(err, user.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
@ -47,21 +47,21 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserLookupParams.UserID != nil &&
|
||||
*query.UserLookupParams.UserID != 0 &&
|
||||
*query.UserLookupParams.UserID != authQuery.Result.UserId {
|
||||
*query.UserLookupParams.UserID != userAuth.UserId {
|
||||
if err := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
UserAuth: userAuth,
|
||||
}); err != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
return false, nil, nil, user.ErrUserNotFound
|
||||
} else {
|
||||
usr, err := s.authInfoStore.GetUserById(ctx, authQuery.Result.UserId)
|
||||
usr, err := s.authInfoStore.GetUserById(ctx, userAuth.UserId)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// if the user has been deleted then remove the entry
|
||||
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
|
||||
UserAuth: authQuery.Result,
|
||||
UserAuth: userAuth,
|
||||
}); errDel != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", errDel)
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
return true, usr, authQuery.Result, nil
|
||||
return true, usr, userAuth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -121,12 +121,12 @@ func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule stri
|
||||
authQuery.AuthModule = authModule
|
||||
authQuery.AuthId = authId
|
||||
authQuery.UserId = userID
|
||||
err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return authQuery.Result, nil
|
||||
return userAuth, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@ -182,7 +182,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *login.GetUs
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
|
||||
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
return s.authInfoStore.GetAuthInfo(ctx, query)
|
||||
}
|
||||
|
||||
@ -201,7 +201,7 @@ func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfo
|
||||
return s.authInfoStore.SetAuthInfo(ctx, cmd)
|
||||
}
|
||||
|
||||
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
|
||||
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
|
||||
return s.authInfoStore.GetExternalUserInfoByLogin(ctx, query)
|
||||
}
|
||||
|
||||
|
@ -201,13 +201,13 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
err = srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err := srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, token.AccessToken, getAuthQuery.Result.OAuthAccessToken)
|
||||
require.Equal(t, token.RefreshToken, getAuthQuery.Result.OAuthRefreshToken)
|
||||
require.Equal(t, token.TokenType, getAuthQuery.Result.OAuthTokenType)
|
||||
require.Equal(t, idToken, getAuthQuery.Result.OAuthIdToken)
|
||||
require.Equal(t, token.AccessToken, authInfo.OAuthAccessToken)
|
||||
require.Equal(t, token.RefreshToken, authInfo.OAuthRefreshToken)
|
||||
require.Equal(t, token.TokenType, authInfo.OAuthTokenType)
|
||||
require.Equal(t, idToken, authInfo.OAuthIdToken)
|
||||
})
|
||||
|
||||
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
||||
@ -261,10 +261,10 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
|
||||
require.Equal(t, authInfo.AuthModule, "test2")
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1"}
|
||||
@ -277,10 +277,10 @@ func TestUserAuth(t *testing.T) {
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
|
||||
require.Equal(t, authInfo.AuthModule, "test1")
|
||||
})
|
||||
|
||||
t.Run("Keeps track of last used auth_module when not using oauth", func(t *testing.T) {
|
||||
@ -334,10 +334,10 @@ func TestUserAuth(t *testing.T) {
|
||||
}
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test2"
|
||||
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test2", getAuthQuery.Result.AuthModule)
|
||||
require.Equal(t, "test2", authInfo.AuthModule)
|
||||
|
||||
// Now reuse first auth module and make sure it's updated to the most recent
|
||||
database.GetTime = func() time.Time { return fixedTime }
|
||||
@ -357,12 +357,12 @@ func TestUserAuth(t *testing.T) {
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test1"
|
||||
authInfoStore.ExpectedOAuth.OAuthAccessToken = "access_token"
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test1", getAuthQuery.Result.AuthModule)
|
||||
require.Equal(t, "test1", authInfo.AuthModule)
|
||||
// make sure oauth info is not overwritten by update date
|
||||
require.Equal(t, "access_token", getAuthQuery.Result.OAuthAccessToken)
|
||||
require.Equal(t, "access_token", authInfo.OAuthAccessToken)
|
||||
|
||||
// Now reuse second auth module and make sure it's updated to the most recent
|
||||
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, 1) }
|
||||
@ -371,9 +371,9 @@ func TestUserAuth(t *testing.T) {
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test2"
|
||||
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test2", getAuthQuery.Result.AuthModule)
|
||||
require.Equal(t, "test2", authInfo.AuthModule)
|
||||
|
||||
// Ensure test 1 did not have its entry modified
|
||||
getAuthQueryUnchanged := &login.GetAuthInfoQuery{
|
||||
@ -382,9 +382,9 @@ func TestUserAuth(t *testing.T) {
|
||||
}
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test1"
|
||||
|
||||
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test1", getAuthQueryUnchanged.Result.AuthModule)
|
||||
require.Equal(t, "test1", authInfo.AuthModule)
|
||||
})
|
||||
|
||||
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
|
||||
@ -520,12 +520,11 @@ func newFakeAuthInfoStore() *FakeAuthInfoStore {
|
||||
return &FakeAuthInfoStore{}
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
|
||||
return f.ExpectedError
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
|
||||
return nil, f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
|
||||
query.Result = f.ExpectedOAuth
|
||||
return f.ExpectedError
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
return f.ExpectedOAuth, f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
|
@ -17,7 +17,7 @@ var (
|
||||
type TeamSyncFunc func(user *user.User, externalUser *ExternalUserInfo) error
|
||||
|
||||
type Service interface {
|
||||
UpsertUser(ctx context.Context, cmd *UpsertUserCommand) error
|
||||
UpsertUser(ctx context.Context, cmd *UpsertUserCommand) (*user.User, error)
|
||||
DisableExternalUser(ctx context.Context, username string) error
|
||||
SetTeamSyncFunc(TeamSyncFunc)
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ type Implementation struct {
|
||||
}
|
||||
|
||||
// UpsertUser updates an existing user, or if it doesn't exist, inserts a new one.
|
||||
func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
|
||||
func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (result *user.User, err error) {
|
||||
var logger log.Logger = logger
|
||||
if cmd.ReqContext != nil && cmd.ReqContext.Logger != nil {
|
||||
logger = cmd.ReqContext.Logger
|
||||
@ -58,12 +58,12 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
|
||||
})
|
||||
if errAuthLookup != nil {
|
||||
if !errors.Is(errAuthLookup, user.ErrUserNotFound) {
|
||||
return errAuthLookup
|
||||
return nil, errAuthLookup
|
||||
}
|
||||
|
||||
if !cmd.SignupAllowed {
|
||||
logger.Warn("Not allowing login, user not found in internal user database and allow signup = false", "authmode", extUser.AuthModule)
|
||||
return login.ErrSignupNotAllowed
|
||||
return nil, login.ErrSignupNotAllowed
|
||||
}
|
||||
|
||||
// quota check (FIXME: (jguer) this should be done in the user service)
|
||||
@ -73,67 +73,67 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
|
||||
limitReached, errLimit := ls.QuotaService.CheckQuotaReached(ctx, quota.TargetSrv(srv), nil)
|
||||
if errLimit != nil {
|
||||
logger.Warn("Error getting user quota.", "error", errLimit)
|
||||
return login.ErrGettingUserQuota
|
||||
return nil, login.ErrGettingUserQuota
|
||||
}
|
||||
if limitReached {
|
||||
return login.ErrUsersQuotaReached
|
||||
return nil, login.ErrUsersQuotaReached
|
||||
}
|
||||
}
|
||||
|
||||
result, errCreateUser := ls.userService.Create(ctx, &user.CreateUserCommand{
|
||||
createdUser, errCreateUser := ls.userService.Create(ctx, &user.CreateUserCommand{
|
||||
Login: extUser.Login,
|
||||
Email: extUser.Email,
|
||||
Name: extUser.Name,
|
||||
SkipOrgSetup: len(extUser.OrgRoles) > 0,
|
||||
})
|
||||
if errCreateUser != nil {
|
||||
return errCreateUser
|
||||
return nil, errCreateUser
|
||||
}
|
||||
|
||||
cmd.Result = &user.User{
|
||||
ID: result.ID,
|
||||
Version: result.Version,
|
||||
Email: result.Email,
|
||||
Name: result.Name,
|
||||
Login: result.Login,
|
||||
Password: result.Password,
|
||||
Salt: result.Salt,
|
||||
Rands: result.Rands,
|
||||
Company: result.Company,
|
||||
EmailVerified: result.EmailVerified,
|
||||
Theme: result.Theme,
|
||||
HelpFlags1: result.HelpFlags1,
|
||||
IsDisabled: result.IsDisabled,
|
||||
IsAdmin: result.IsAdmin,
|
||||
IsServiceAccount: result.IsServiceAccount,
|
||||
OrgID: result.OrgID,
|
||||
Created: result.Created,
|
||||
Updated: result.Updated,
|
||||
LastSeenAt: result.LastSeenAt,
|
||||
result = &user.User{
|
||||
ID: createdUser.ID,
|
||||
Version: createdUser.Version,
|
||||
Email: createdUser.Email,
|
||||
Name: createdUser.Name,
|
||||
Login: createdUser.Login,
|
||||
Password: createdUser.Password,
|
||||
Salt: createdUser.Salt,
|
||||
Rands: createdUser.Rands,
|
||||
Company: createdUser.Company,
|
||||
EmailVerified: createdUser.EmailVerified,
|
||||
Theme: createdUser.Theme,
|
||||
HelpFlags1: createdUser.HelpFlags1,
|
||||
IsDisabled: createdUser.IsDisabled,
|
||||
IsAdmin: createdUser.IsAdmin,
|
||||
IsServiceAccount: createdUser.IsServiceAccount,
|
||||
OrgID: createdUser.OrgID,
|
||||
Created: createdUser.Created,
|
||||
Updated: createdUser.Updated,
|
||||
LastSeenAt: createdUser.LastSeenAt,
|
||||
}
|
||||
|
||||
if extUser.AuthModule != "" {
|
||||
cmd2 := &login.SetAuthInfoCommand{
|
||||
UserId: cmd.Result.ID,
|
||||
UserId: result.ID,
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
OAuthToken: extUser.OAuthToken,
|
||||
}
|
||||
if errSetAuth := ls.AuthInfoService.SetAuthInfo(ctx, cmd2); errSetAuth != nil {
|
||||
return errSetAuth
|
||||
return nil, errSetAuth
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cmd.Result = usr
|
||||
result = usr
|
||||
|
||||
if errUserMod := ls.updateUser(ctx, cmd.Result, extUser); errUserMod != nil {
|
||||
return errUserMod
|
||||
if errUserMod := ls.updateUser(ctx, result, extUser); errUserMod != nil {
|
||||
return nil, errUserMod
|
||||
}
|
||||
|
||||
// Always persist the latest token at log-in
|
||||
if extUser.AuthModule != "" && extUser.OAuthToken != nil {
|
||||
if errAuthMod := ls.updateUserAuth(ctx, cmd.Result, extUser); errAuthMod != nil {
|
||||
return errAuthMod
|
||||
if errAuthMod := ls.updateUserAuth(ctx, result, extUser); errAuthMod != nil {
|
||||
return nil, errAuthMod
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,31 +141,31 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
|
||||
// Re-enable user when it found in LDAP
|
||||
if errDisableUser := ls.userService.Disable(ctx,
|
||||
&user.DisableUserCommand{
|
||||
UserID: cmd.Result.ID, IsDisabled: false}); errDisableUser != nil {
|
||||
return errDisableUser
|
||||
UserID: result.ID, IsDisabled: false}); errDisableUser != nil {
|
||||
return nil, errDisableUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if errSyncRole := ls.syncOrgRoles(ctx, cmd.Result, extUser); errSyncRole != nil {
|
||||
return errSyncRole
|
||||
if errSyncRole := ls.syncOrgRoles(ctx, result, extUser); errSyncRole != nil {
|
||||
return nil, errSyncRole
|
||||
}
|
||||
|
||||
// Sync isGrafanaAdmin permission
|
||||
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
|
||||
if errPerms := ls.userService.UpdatePermissions(ctx, cmd.Result.ID, *extUser.IsGrafanaAdmin); errPerms != nil {
|
||||
return errPerms
|
||||
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != result.IsAdmin {
|
||||
if errPerms := ls.userService.UpdatePermissions(ctx, result.ID, *extUser.IsGrafanaAdmin); errPerms != nil {
|
||||
return nil, errPerms
|
||||
}
|
||||
}
|
||||
|
||||
// There are external providers where we want to completely skip team synchronization see - https://github.com/grafana/grafana/issues/62175
|
||||
if ls.TeamSync != nil && !extUser.SkipTeamSync {
|
||||
if errTeamSync := ls.TeamSync(cmd.Result, extUser); errTeamSync != nil {
|
||||
return errTeamSync
|
||||
if errTeamSync := ls.TeamSync(result, extUser); errTeamSync != nil {
|
||||
return nil, errTeamSync
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (ls *Implementation) DisableExternalUser(ctx context.Context, username string) error {
|
||||
@ -174,11 +174,11 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
|
||||
LoginOrEmail: username,
|
||||
}
|
||||
|
||||
if err := ls.AuthInfoService.GetExternalUserInfoByLogin(ctx, userQuery); err != nil {
|
||||
userInfo, err := ls.AuthInfoService.GetExternalUserInfoByLogin(ctx, userQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userInfo := userQuery.Result
|
||||
if userInfo.IsDisabled {
|
||||
return nil
|
||||
}
|
||||
@ -186,12 +186,12 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
|
||||
logger.Debug(
|
||||
"Disabling external user",
|
||||
"user",
|
||||
userQuery.Result.Login,
|
||||
userInfo.Login,
|
||||
)
|
||||
|
||||
// Mark user as disabled in grafana db
|
||||
disableUserCmd := &user.DisableUserCommand{
|
||||
UserID: userQuery.Result.UserId,
|
||||
UserID: userInfo.UserId,
|
||||
IsDisabled: true,
|
||||
}
|
||||
|
||||
@ -199,7 +199,7 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
|
||||
logger.Debug(
|
||||
"Error disabling external user",
|
||||
"user",
|
||||
userQuery.Result.Login,
|
||||
userInfo.Login,
|
||||
"message",
|
||||
err.Error(),
|
||||
)
|
||||
|
@ -14,13 +14,11 @@ type LoginServiceMock struct {
|
||||
ExpectedError error
|
||||
}
|
||||
|
||||
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
|
||||
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (*user.User, error) {
|
||||
if s.ExpectedUserFunc != nil {
|
||||
cmd.Result = s.ExpectedUserFunc(cmd)
|
||||
return s.ExpectedError
|
||||
return s.ExpectedUserFunc(cmd), s.ExpectedError
|
||||
}
|
||||
cmd.Result = s.ExpectedUser
|
||||
return s.ExpectedError
|
||||
return s.ExpectedUser, s.ExpectedError
|
||||
}
|
||||
|
||||
func (s LoginServiceMock) DisableExternalUser(ctx context.Context, username string) error {
|
||||
|
@ -85,7 +85,7 @@ func Test_teamSync(t *testing.T) {
|
||||
var actualExternalUser *login.ExternalUserInfo
|
||||
|
||||
t.Run("login.TeamSync should not be called when nil", func(t *testing.T) {
|
||||
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
require.Nil(t, err)
|
||||
assert.Nil(t, actualUser)
|
||||
assert.Nil(t, actualExternalUser)
|
||||
@ -97,7 +97,7 @@ func Test_teamSync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
loginsvc.TeamSync = teamSyncFunc
|
||||
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, actualUser, expectedUser)
|
||||
assert.Equal(t, actualExternalUser, upsertCmd.ExternalUser)
|
||||
@ -120,7 +120,7 @@ func Test_teamSync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
loginsvc.TeamSync = teamSyncFunc
|
||||
err := loginsvc.UpsertUser(context.Background(), upsertCmdSkipTeamSync)
|
||||
_, err := loginsvc.UpsertUser(context.Background(), upsertCmdSkipTeamSync)
|
||||
require.Nil(t, err)
|
||||
assert.Nil(t, actualUser)
|
||||
assert.Nil(t, actualExternalUser)
|
||||
@ -131,7 +131,7 @@ func Test_teamSync(t *testing.T) {
|
||||
return errors.New("teamsync test error")
|
||||
}
|
||||
loginsvc.TeamSync = teamSyncFunc
|
||||
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
@ -154,7 +154,7 @@ func TestUpsertUser_crashOnLog_issue62538(t *testing.T) {
|
||||
|
||||
var err error
|
||||
require.NotPanics(t, func() {
|
||||
err = loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
_, err = loginsvc.UpsertUser(context.Background(), upsertCmd)
|
||||
})
|
||||
require.ErrorIs(t, err, login.ErrSignupNotAllowed)
|
||||
}
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
|
||||
type LoginServiceFake struct{}
|
||||
|
||||
func (l *LoginServiceFake) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
|
||||
return nil
|
||||
func (l *LoginServiceFake) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (*user.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (l *LoginServiceFake) DisableExternalUser(ctx context.Context, username string) error {
|
||||
return nil
|
||||
@ -39,10 +39,9 @@ func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *login.
|
||||
return a.ExpectedUser, a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
|
||||
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
a.LatestUserID = query.UserId
|
||||
query.Result = a.ExpectedUserAuth
|
||||
return a.ExpectedError
|
||||
return a.ExpectedUserAuth, a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
|
||||
@ -65,9 +64,8 @@ func (a *AuthInfoServiceFake) UpdateAuthInfo(ctx context.Context, cmd *login.Upd
|
||||
return a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
|
||||
query.Result = a.ExpectedExternalUser
|
||||
return a.ExpectedError
|
||||
func (a *AuthInfoServiceFake) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
|
||||
return a.ExpectedExternalUser, a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
|
||||
|
@ -91,8 +91,6 @@ type UpsertUserCommand struct {
|
||||
ExternalUser *ExternalUserInfo
|
||||
UserLookupParams
|
||||
SignupAllowed bool
|
||||
|
||||
Result *user.User
|
||||
}
|
||||
|
||||
type SetAuthInfoCommand struct {
|
||||
@ -141,16 +139,12 @@ type UserLookupParams struct {
|
||||
|
||||
type GetExternalUserInfoByLoginQuery struct {
|
||||
LoginOrEmail string
|
||||
|
||||
Result *ExternalUserInfo
|
||||
}
|
||||
|
||||
type GetAuthInfoQuery struct {
|
||||
UserId int64
|
||||
AuthModule string
|
||||
AuthId string
|
||||
|
||||
Result *UserAuth
|
||||
}
|
||||
|
||||
type GetUserLabelsQuery struct {
|
||||
|
@ -59,7 +59,8 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
|
||||
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
|
||||
@ -69,10 +70,10 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
|
||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return buildOAuthTokenFromAuthInfo(authInfoQuery.Result)
|
||||
return buildOAuthTokenFromAuthInfo(authInfo)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -94,7 +95,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*l
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
|
||||
err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
@ -103,10 +104,10 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*l
|
||||
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
return nil, false, err
|
||||
}
|
||||
if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") {
|
||||
if !strings.Contains(authInfo.AuthModule, "oauth") {
|
||||
return nil, false, nil
|
||||
}
|
||||
return authInfoQuery.Result, true, nil
|
||||
return authInfo, true, nil
|
||||
}
|
||||
|
||||
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||
|
@ -119,12 +119,11 @@ func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{}
|
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
// User's token data had not been updated
|
||||
resultUsr := authInfoQuery.Result
|
||||
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
|
||||
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
|
||||
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
|
||||
@ -193,15 +192,15 @@ func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{}
|
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType)
|
||||
assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken)
|
||||
assert.Equal(t, authInfo.OAuthExpiry, newToken.Expiry)
|
||||
assert.Equal(t, authInfo.OAuthRefreshToken, newToken.RefreshToken)
|
||||
assert.Equal(t, authInfo.OAuthTokenType, newToken.TokenType)
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
|
||||
@ -318,13 +317,12 @@ type FakeAuthInfoStore struct {
|
||||
ExpectedLoginStats login.LoginStats
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
|
||||
return f.ExpectedError
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
|
||||
return nil, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
|
||||
query.Result = f.ExpectedOAuth
|
||||
return f.ExpectedError
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
return f.ExpectedOAuth, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
|
Loading…
Reference in New Issue
Block a user