Chore: Remove result fields from login (#65136)

* remove result fields from login

* fix tests

* fix tests

* another shadowing
This commit is contained in:
Serge Zaitsev 2023-03-28 20:32:21 +02:00 committed by GitHub
parent 3b37135b5b
commit a38f230d37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 166 additions and 181 deletions

View File

@ -294,7 +294,7 @@ func (hs *HTTPServer) AdminDisableUser(c *contextmodel.ReqContext) response.Resp
// External users shouldn't be disabled from API
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
if _, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
return response.Error(500, "Could not disable external user", nil)
}
@ -337,7 +337,7 @@ func (hs *HTTPServer) AdminEnableUser(c *contextmodel.ReqContext) response.Respo
// External users shouldn't be disabled from API
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
if _, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authInfoQuery); !errors.Is(err, user.ErrUserNotFound) {
return response.Error(500, "Could not enable external user", nil)
}

View File

@ -318,8 +318,8 @@ func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
// If SAML is enabled and this is a SAML user use saml logout
if hs.samlSingleLogoutEnabled() {
getAuthQuery := loginservice.GetAuthInfoQuery{UserId: c.UserID}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
if getAuthQuery.Result.AuthModule == loginservice.SAMLAuthModule {
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
if authInfo.AuthModule == loginservice.SAMLAuthModule {
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
return
}

View File

@ -351,18 +351,19 @@ func (hs *HTTPServer) SyncUser(
},
}
if err := hs.Login.UpsertUser(ctx.Req.Context(), cmd); err != nil {
upsertedUser, err := hs.Login.UpsertUser(ctx.Req.Context(), cmd)
if err != nil {
return nil, err
}
// Do not expose disabled status,
// just show incorrect user credentials error (see #17947)
if cmd.Result.IsDisabled {
oauthLogger.Warn("User is disabled", "user", cmd.Result.Login)
if upsertedUser.IsDisabled {
oauthLogger.Warn("User is disabled", "user", upsertedUser.Login)
return nil, login.ErrInvalidCredentials
}
return cmd.Result, nil
return upsertedUser, nil
}
func (hs *HTTPServer) hashStatecode(code, seed string) string {

View File

@ -38,8 +38,8 @@ func (hs *HTTPServer) SendResetPasswordEmail(c *contextmodel.ReqContext) respons
}
getAuthQuery := login.GetAuthInfoQuery{UserId: usr.ID}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authModule := getAuthQuery.Result.AuthModule
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authModule := authInfo.AuthModule
if authModule == login.LDAPAuthModule || authModule == login.AuthProxyAuthModule {
return response.Error(401, "Not allowed to reset password for LDAP or Auth Proxy user", nil)
}

View File

@ -63,8 +63,8 @@ func (hs *HTTPServer) getUserUserProfile(c *contextmodel.ReqContext, userID int6
getAuthQuery := login.GetAuthInfoQuery{UserId: userID}
userProfile.AuthLabels = []string{}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authLabel := login.GetAuthProviderLabel(getAuthQuery.Result.AuthModule)
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authLabel := login.GetAuthProviderLabel(authInfo.AuthModule)
userProfile.AuthLabels = append(userProfile.AuthLabels, authLabel)
userProfile.IsExternal = true
userProfile.IsExternallySynced = login.IsExternallySynced(hs.Cfg, getAuthQuery.Result.AuthModule)
@ -225,7 +225,7 @@ func (hs *HTTPServer) handleUpdateUser(ctx context.Context, cmd user.UpdateUserC
func (hs *HTTPServer) isExternalUser(ctx context.Context, userID int64) (bool, error) {
getAuthQuery := login.GetAuthInfoQuery{UserId: userID}
var err error
if err = hs.authInfoService.GetAuthInfo(ctx, &getAuthQuery); err == nil {
if _, err = hs.authInfoService.GetAuthInfo(ctx, &getAuthQuery); err == nil {
return true, nil
}
@ -434,8 +434,8 @@ func (hs *HTTPServer) ChangeUserPassword(c *contextmodel.ReqContext) response.Re
}
getAuthQuery := login.GetAuthInfoQuery{UserId: usr.ID}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authModule := getAuthQuery.Result.AuthModule
if authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil {
authModule := authInfo.AuthModule
if authModule == login.LDAPAuthModule || authModule == login.AuthProxyAuthModule {
return response.Error(400, "Not allowed to reset password for LDAP or Auth Proxy user", nil)
}

View File

@ -59,10 +59,6 @@ var loginUsingLDAP = func(ctx context.Context, query *login.LoginUserQuery,
UserID: nil,
},
}
if err = loginService.UpsertUser(ctx, upsert); err != nil {
return true, err
}
query.User = upsert.Result
return true, nil
query.User, err = loginService.UpsertUser(ctx, upsert)
return true, err
}

View File

@ -280,16 +280,16 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
// Check auth info fist
if identity.AuthID != "" && identity.AuthModule != "" {
query := &login.GetAuthInfoQuery{AuthId: identity.AuthID, AuthModule: identity.AuthModule}
errGetAuthInfo := s.authInfoService.GetAuthInfo(ctx, query)
authInfo, errGetAuthInfo := s.authInfoService.GetAuthInfo(ctx, query)
if errGetAuthInfo != nil && !errors.Is(errGetAuthInfo, user.ErrUserNotFound) {
return nil, nil, errGetAuthInfo
}
if !errors.Is(errGetAuthInfo, user.ErrUserNotFound) {
usr, errGetByID := s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: query.Result.UserId})
usr, errGetByID := s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: authInfo.UserId})
if errGetByID == nil {
return usr, query.Result, nil
return usr, authInfo, nil
}
if !errors.Is(errGetByID, user.ErrUserNotFound) {
@ -298,7 +298,7 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
// if the user connected to user auth does not exist try to clean it up
if errors.Is(errGetByID, user.ErrUserNotFound) {
if err := s.authInfoService.DeleteUserAuthInfo(ctx, query.Result.UserId); err != nil {
if err := s.authInfoService.DeleteUserAuthInfo(ctx, authInfo.UserId); err != nil {
s.log.FromContext(ctx).Error("Failed to clean up user auth", "error", err, "auth_module", identity.AuthModule, "auth_id", identity.AuthID)
}
}
@ -316,11 +316,10 @@ func (s *UserSync) getUser(ctx context.Context, identity *authn.Identity) (*user
// so we need to find the user first then check for the userAuth connection by module and userID
if identity.AuthModule == login.GenericOAuthModule {
query := &login.GetAuthInfoQuery{AuthModule: identity.AuthModule, UserId: usr.ID}
err := s.authInfoService.GetAuthInfo(ctx, query)
userAuth, err = s.authInfoService.GetAuthInfo(ctx, query)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return nil, nil, err
}
userAuth = query.Result
}
return usr, userAuth, nil

View File

@ -127,7 +127,7 @@ func (h *ContextHandler) initContextWithJWT(ctx *contextmodel.ReqContext, orgId
Email: &query.Email,
},
}
if err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil {
if _, err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil {
ctx.Logger.Error("Failed to upsert JWT user", "error", err)
return false
}

View File

@ -241,11 +241,12 @@ func (auth *AuthProxy) LoginViaLDAP(reqCtx *contextmodel.ReqContext) (int64, err
UserID: nil,
},
}
if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
u, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
if err != nil {
return 0, err
}
return upsert.Result.ID, nil
return u.ID, nil
}
// loginViaHeader logs in user from the header only
@ -304,12 +305,12 @@ func (auth *AuthProxy) loginViaHeader(reqCtx *contextmodel.ReqContext) (int64, e
},
}
err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
result, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
if err != nil {
return 0, err
}
return upsert.Result.ID, nil
return result.ID, nil
}
// getDecodedHeader gets decoded value of a header with given headerName

View File

@ -193,7 +193,7 @@ func (s *Service) PostSyncUserWithLDAP(c *contextmodel.ReqContext) response.Resp
}
authModuleQuery := &login.GetAuthInfoQuery{UserId: usr.ID, AuthModule: login.LDAPAuthModule}
if err := s.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP
if _, err := s.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP
if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, user.ErrUserNotFound.Error(), nil)
}
@ -239,7 +239,7 @@ func (s *Service) PostSyncUserWithLDAP(c *contextmodel.ReqContext) response.Resp
},
}
err = s.loginService.UpsertUser(c.Req.Context(), upsertCmd)
_, err = s.loginService.UpsertUser(c.Req.Context(), upsertCmd)
if err != nil {
return response.Error(http.StatusInternalServerError, "Failed to update the user", err)
}

View File

@ -9,17 +9,17 @@ import (
type AuthInfoService interface {
LookupAndUpdate(ctx context.Context, query *GetUserByAuthInfoQuery) (*user.User, error)
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) error
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) error
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) (*ExternalUserInfo, error)
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error
DeleteUserAuthInfo(ctx context.Context, userID int64) error
}
type Store interface {
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) error
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) error
GetExternalUserInfoByLogin(ctx context.Context, query *GetExternalUserInfoByLoginQuery) (*ExternalUserInfo, error)
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error

View File

@ -34,35 +34,36 @@ func ProvideAuthInfoStore(sqlStore db.DB, secretsService secrets.Service, userSe
return store
}
func (s *AuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
func (s *AuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
userQuery := user.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
usr, err := s.userService.GetByLogin(ctx, &userQuery)
if err != nil {
return err
return nil, err
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.ID}
if err := s.GetAuthInfo(ctx, authInfoQuery); err != nil {
return err
authInfo, err := s.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
return nil, err
}
query.Result = &login.ExternalUserInfo{
result := &login.ExternalUserInfo{
UserId: usr.ID,
Login: usr.Login,
Email: usr.Email,
Name: usr.Name,
IsDisabled: usr.IsDisabled,
AuthModule: authInfoQuery.Result.AuthModule,
AuthId: authInfoQuery.Result.AuthId,
AuthModule: authInfo.AuthModule,
AuthId: authInfo.AuthId,
}
return nil
return result, nil
}
// GetAuthInfo returns the auth info for a user
// It will return the latest auth info for a user
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
if query.UserId == 0 && query.AuthId == "" {
return user.ErrUserNotFound
return nil, user.ErrUserNotFound
}
userAuth := &login.UserAuth{
@ -79,36 +80,35 @@ func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInf
return err
})
if err != nil {
return err
return nil, err
}
if !has {
return user.ErrUserNotFound
return nil, user.ErrUserNotFound
}
secretAccessToken, err := s.decodeAndDecrypt(userAuth.OAuthAccessToken)
if err != nil {
return err
return nil, err
}
secretRefreshToken, err := s.decodeAndDecrypt(userAuth.OAuthRefreshToken)
if err != nil {
return err
return nil, err
}
secretTokenType, err := s.decodeAndDecrypt(userAuth.OAuthTokenType)
if err != nil {
return err
return nil, err
}
secretIdToken, err := s.decodeAndDecrypt(userAuth.OAuthIdToken)
if err != nil {
return err
return nil, err
}
userAuth.OAuthAccessToken = secretAccessToken
userAuth.OAuthRefreshToken = secretRefreshToken
userAuth.OAuthTokenType = secretTokenType
userAuth.OAuthIdToken = secretIdToken
query.Result = userAuth
return nil
return userAuth, nil
}
func (s *AuthInfoStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {

View File

@ -38,7 +38,7 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
authQuery.AuthModule = query.AuthModule
authQuery.AuthId = query.AuthId
err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
if !errors.Is(err, user.ErrUserNotFound) {
if err != nil {
return false, nil, nil, err
@ -47,21 +47,21 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
// if user id was specified and doesn't match the user_auth entry, remove it
if query.UserLookupParams.UserID != nil &&
*query.UserLookupParams.UserID != 0 &&
*query.UserLookupParams.UserID != authQuery.Result.UserId {
*query.UserLookupParams.UserID != userAuth.UserId {
if err := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
UserAuth: userAuth,
}); err != nil {
s.logger.Error("Error removing user_auth entry", "error", err)
}
return false, nil, nil, user.ErrUserNotFound
} else {
usr, err := s.authInfoStore.GetUserById(ctx, authQuery.Result.UserId)
usr, err := s.authInfoStore.GetUserById(ctx, userAuth.UserId)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// if the user has been deleted then remove the entry
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
UserAuth: userAuth,
}); errDel != nil {
s.logger.Error("Error removing user_auth entry", "error", errDel)
}
@ -72,7 +72,7 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserB
return false, nil, nil, err
}
return true, usr, authQuery.Result, nil
return true, usr, userAuth, nil
}
}
}
@ -121,12 +121,12 @@ func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule stri
authQuery.AuthModule = authModule
authQuery.AuthId = authId
authQuery.UserId = userID
err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
if err != nil {
return nil, err
}
return authQuery.Result, nil
return userAuth, nil
}
return nil, nil
}
@ -182,7 +182,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *login.GetUs
return usr, nil
}
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
return s.authInfoStore.GetAuthInfo(ctx, query)
}
@ -201,7 +201,7 @@ func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfo
return s.authInfoStore.SetAuthInfo(ctx, cmd)
}
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
return s.authInfoStore.GetExternalUserInfoByLogin(ctx, query)
}

View File

@ -201,13 +201,13 @@ func TestUserAuth(t *testing.T) {
UserId: user.ID,
}
err = srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err := srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, token.AccessToken, getAuthQuery.Result.OAuthAccessToken)
require.Equal(t, token.RefreshToken, getAuthQuery.Result.OAuthRefreshToken)
require.Equal(t, token.TokenType, getAuthQuery.Result.OAuthTokenType)
require.Equal(t, idToken, getAuthQuery.Result.OAuthIdToken)
require.Equal(t, token.AccessToken, authInfo.OAuthAccessToken)
require.Equal(t, token.RefreshToken, authInfo.OAuthRefreshToken)
require.Equal(t, token.TokenType, authInfo.OAuthTokenType)
require.Equal(t, idToken, authInfo.OAuthIdToken)
})
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
@ -261,10 +261,10 @@ func TestUserAuth(t *testing.T) {
UserId: user.ID,
}
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
require.Equal(t, authInfo.AuthModule, "test2")
// "log in" again with the first auth module
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1"}
@ -277,10 +277,10 @@ func TestUserAuth(t *testing.T) {
UserId: user.ID,
}
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
require.Equal(t, authInfo.AuthModule, "test1")
})
t.Run("Keeps track of last used auth_module when not using oauth", func(t *testing.T) {
@ -334,10 +334,10 @@ func TestUserAuth(t *testing.T) {
}
authInfoStore.ExpectedOAuth.AuthModule = "test2"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test2", getAuthQuery.Result.AuthModule)
require.Equal(t, "test2", authInfo.AuthModule)
// Now reuse first auth module and make sure it's updated to the most recent
database.GetTime = func() time.Time { return fixedTime }
@ -357,12 +357,12 @@ func TestUserAuth(t *testing.T) {
require.Equal(t, user.Login, userlogin)
authInfoStore.ExpectedOAuth.AuthModule = "test1"
authInfoStore.ExpectedOAuth.OAuthAccessToken = "access_token"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test1", getAuthQuery.Result.AuthModule)
require.Equal(t, "test1", authInfo.AuthModule)
// make sure oauth info is not overwritten by update date
require.Equal(t, "access_token", getAuthQuery.Result.OAuthAccessToken)
require.Equal(t, "access_token", authInfo.OAuthAccessToken)
// Now reuse second auth module and make sure it's updated to the most recent
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, 1) }
@ -371,9 +371,9 @@ func TestUserAuth(t *testing.T) {
require.Equal(t, user.Login, userlogin)
authInfoStore.ExpectedOAuth.AuthModule = "test2"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test2", getAuthQuery.Result.AuthModule)
require.Equal(t, "test2", authInfo.AuthModule)
// Ensure test 1 did not have its entry modified
getAuthQueryUnchanged := &login.GetAuthInfoQuery{
@ -382,9 +382,9 @@ func TestUserAuth(t *testing.T) {
}
authInfoStore.ExpectedOAuth.AuthModule = "test1"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
require.Nil(t, err)
require.Equal(t, "test1", getAuthQueryUnchanged.Result.AuthModule)
require.Equal(t, "test1", authInfo.AuthModule)
})
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
@ -520,12 +520,11 @@ func newFakeAuthInfoStore() *FakeAuthInfoStore {
return &FakeAuthInfoStore{}
}
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
return f.ExpectedError
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
return nil, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
query.Result = f.ExpectedOAuth
return f.ExpectedError
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
return f.ExpectedOAuth, f.ExpectedError
}
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
return f.ExpectedError

View File

@ -17,7 +17,7 @@ var (
type TeamSyncFunc func(user *user.User, externalUser *ExternalUserInfo) error
type Service interface {
UpsertUser(ctx context.Context, cmd *UpsertUserCommand) error
UpsertUser(ctx context.Context, cmd *UpsertUserCommand) (*user.User, error)
DisableExternalUser(ctx context.Context, username string) error
SetTeamSyncFunc(TeamSyncFunc)
}

View File

@ -43,7 +43,7 @@ type Implementation struct {
}
// UpsertUser updates an existing user, or if it doesn't exist, inserts a new one.
func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (result *user.User, err error) {
var logger log.Logger = logger
if cmd.ReqContext != nil && cmd.ReqContext.Logger != nil {
logger = cmd.ReqContext.Logger
@ -58,12 +58,12 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
})
if errAuthLookup != nil {
if !errors.Is(errAuthLookup, user.ErrUserNotFound) {
return errAuthLookup
return nil, errAuthLookup
}
if !cmd.SignupAllowed {
logger.Warn("Not allowing login, user not found in internal user database and allow signup = false", "authmode", extUser.AuthModule)
return login.ErrSignupNotAllowed
return nil, login.ErrSignupNotAllowed
}
// quota check (FIXME: (jguer) this should be done in the user service)
@ -73,67 +73,67 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
limitReached, errLimit := ls.QuotaService.CheckQuotaReached(ctx, quota.TargetSrv(srv), nil)
if errLimit != nil {
logger.Warn("Error getting user quota.", "error", errLimit)
return login.ErrGettingUserQuota
return nil, login.ErrGettingUserQuota
}
if limitReached {
return login.ErrUsersQuotaReached
return nil, login.ErrUsersQuotaReached
}
}
result, errCreateUser := ls.userService.Create(ctx, &user.CreateUserCommand{
createdUser, errCreateUser := ls.userService.Create(ctx, &user.CreateUserCommand{
Login: extUser.Login,
Email: extUser.Email,
Name: extUser.Name,
SkipOrgSetup: len(extUser.OrgRoles) > 0,
})
if errCreateUser != nil {
return errCreateUser
return nil, errCreateUser
}
cmd.Result = &user.User{
ID: result.ID,
Version: result.Version,
Email: result.Email,
Name: result.Name,
Login: result.Login,
Password: result.Password,
Salt: result.Salt,
Rands: result.Rands,
Company: result.Company,
EmailVerified: result.EmailVerified,
Theme: result.Theme,
HelpFlags1: result.HelpFlags1,
IsDisabled: result.IsDisabled,
IsAdmin: result.IsAdmin,
IsServiceAccount: result.IsServiceAccount,
OrgID: result.OrgID,
Created: result.Created,
Updated: result.Updated,
LastSeenAt: result.LastSeenAt,
result = &user.User{
ID: createdUser.ID,
Version: createdUser.Version,
Email: createdUser.Email,
Name: createdUser.Name,
Login: createdUser.Login,
Password: createdUser.Password,
Salt: createdUser.Salt,
Rands: createdUser.Rands,
Company: createdUser.Company,
EmailVerified: createdUser.EmailVerified,
Theme: createdUser.Theme,
HelpFlags1: createdUser.HelpFlags1,
IsDisabled: createdUser.IsDisabled,
IsAdmin: createdUser.IsAdmin,
IsServiceAccount: createdUser.IsServiceAccount,
OrgID: createdUser.OrgID,
Created: createdUser.Created,
Updated: createdUser.Updated,
LastSeenAt: createdUser.LastSeenAt,
}
if extUser.AuthModule != "" {
cmd2 := &login.SetAuthInfoCommand{
UserId: cmd.Result.ID,
UserId: result.ID,
AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId,
OAuthToken: extUser.OAuthToken,
}
if errSetAuth := ls.AuthInfoService.SetAuthInfo(ctx, cmd2); errSetAuth != nil {
return errSetAuth
return nil, errSetAuth
}
}
} else {
cmd.Result = usr
result = usr
if errUserMod := ls.updateUser(ctx, cmd.Result, extUser); errUserMod != nil {
return errUserMod
if errUserMod := ls.updateUser(ctx, result, extUser); errUserMod != nil {
return nil, errUserMod
}
// Always persist the latest token at log-in
if extUser.AuthModule != "" && extUser.OAuthToken != nil {
if errAuthMod := ls.updateUserAuth(ctx, cmd.Result, extUser); errAuthMod != nil {
return errAuthMod
if errAuthMod := ls.updateUserAuth(ctx, result, extUser); errAuthMod != nil {
return nil, errAuthMod
}
}
@ -141,31 +141,31 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *login.UpsertUserC
// Re-enable user when it found in LDAP
if errDisableUser := ls.userService.Disable(ctx,
&user.DisableUserCommand{
UserID: cmd.Result.ID, IsDisabled: false}); errDisableUser != nil {
return errDisableUser
UserID: result.ID, IsDisabled: false}); errDisableUser != nil {
return nil, errDisableUser
}
}
}
if errSyncRole := ls.syncOrgRoles(ctx, cmd.Result, extUser); errSyncRole != nil {
return errSyncRole
if errSyncRole := ls.syncOrgRoles(ctx, result, extUser); errSyncRole != nil {
return nil, errSyncRole
}
// Sync isGrafanaAdmin permission
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
if errPerms := ls.userService.UpdatePermissions(ctx, cmd.Result.ID, *extUser.IsGrafanaAdmin); errPerms != nil {
return errPerms
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != result.IsAdmin {
if errPerms := ls.userService.UpdatePermissions(ctx, result.ID, *extUser.IsGrafanaAdmin); errPerms != nil {
return nil, errPerms
}
}
// There are external providers where we want to completely skip team synchronization see - https://github.com/grafana/grafana/issues/62175
if ls.TeamSync != nil && !extUser.SkipTeamSync {
if errTeamSync := ls.TeamSync(cmd.Result, extUser); errTeamSync != nil {
return errTeamSync
if errTeamSync := ls.TeamSync(result, extUser); errTeamSync != nil {
return nil, errTeamSync
}
}
return nil
return result, nil
}
func (ls *Implementation) DisableExternalUser(ctx context.Context, username string) error {
@ -174,11 +174,11 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
LoginOrEmail: username,
}
if err := ls.AuthInfoService.GetExternalUserInfoByLogin(ctx, userQuery); err != nil {
userInfo, err := ls.AuthInfoService.GetExternalUserInfoByLogin(ctx, userQuery)
if err != nil {
return err
}
userInfo := userQuery.Result
if userInfo.IsDisabled {
return nil
}
@ -186,12 +186,12 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
logger.Debug(
"Disabling external user",
"user",
userQuery.Result.Login,
userInfo.Login,
)
// Mark user as disabled in grafana db
disableUserCmd := &user.DisableUserCommand{
UserID: userQuery.Result.UserId,
UserID: userInfo.UserId,
IsDisabled: true,
}
@ -199,7 +199,7 @@ func (ls *Implementation) DisableExternalUser(ctx context.Context, username stri
logger.Debug(
"Error disabling external user",
"user",
userQuery.Result.Login,
userInfo.Login,
"message",
err.Error(),
)

View File

@ -14,13 +14,11 @@ type LoginServiceMock struct {
ExpectedError error
}
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (*user.User, error) {
if s.ExpectedUserFunc != nil {
cmd.Result = s.ExpectedUserFunc(cmd)
return s.ExpectedError
return s.ExpectedUserFunc(cmd), s.ExpectedError
}
cmd.Result = s.ExpectedUser
return s.ExpectedError
return s.ExpectedUser, s.ExpectedError
}
func (s LoginServiceMock) DisableExternalUser(ctx context.Context, username string) error {

View File

@ -85,7 +85,7 @@ func Test_teamSync(t *testing.T) {
var actualExternalUser *login.ExternalUserInfo
t.Run("login.TeamSync should not be called when nil", func(t *testing.T) {
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
require.Nil(t, err)
assert.Nil(t, actualUser)
assert.Nil(t, actualExternalUser)
@ -97,7 +97,7 @@ func Test_teamSync(t *testing.T) {
return nil
}
loginsvc.TeamSync = teamSyncFunc
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
require.Nil(t, err)
assert.Equal(t, actualUser, expectedUser)
assert.Equal(t, actualExternalUser, upsertCmd.ExternalUser)
@ -120,7 +120,7 @@ func Test_teamSync(t *testing.T) {
return nil
}
loginsvc.TeamSync = teamSyncFunc
err := loginsvc.UpsertUser(context.Background(), upsertCmdSkipTeamSync)
_, err := loginsvc.UpsertUser(context.Background(), upsertCmdSkipTeamSync)
require.Nil(t, err)
assert.Nil(t, actualUser)
assert.Nil(t, actualExternalUser)
@ -131,7 +131,7 @@ func Test_teamSync(t *testing.T) {
return errors.New("teamsync test error")
}
loginsvc.TeamSync = teamSyncFunc
err := loginsvc.UpsertUser(context.Background(), upsertCmd)
_, err := loginsvc.UpsertUser(context.Background(), upsertCmd)
require.Error(t, err)
})
})
@ -154,7 +154,7 @@ func TestUpsertUser_crashOnLog_issue62538(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = loginsvc.UpsertUser(context.Background(), upsertCmd)
_, err = loginsvc.UpsertUser(context.Background(), upsertCmd)
})
require.ErrorIs(t, err, login.ErrSignupNotAllowed)
}

View File

@ -9,8 +9,8 @@ import (
type LoginServiceFake struct{}
func (l *LoginServiceFake) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) error {
return nil
func (l *LoginServiceFake) UpsertUser(ctx context.Context, cmd *login.UpsertUserCommand) (*user.User, error) {
return nil, nil
}
func (l *LoginServiceFake) DisableExternalUser(ctx context.Context, username string) error {
return nil
@ -39,10 +39,9 @@ func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *login.
return a.ExpectedUser, a.ExpectedError
}
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
a.LatestUserID = query.UserId
query.Result = a.ExpectedUserAuth
return a.ExpectedError
return a.ExpectedUserAuth, a.ExpectedError
}
func (a *AuthInfoServiceFake) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
@ -65,9 +64,8 @@ func (a *AuthInfoServiceFake) UpdateAuthInfo(ctx context.Context, cmd *login.Upd
return a.ExpectedError
}
func (a *AuthInfoServiceFake) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
query.Result = a.ExpectedExternalUser
return a.ExpectedError
func (a *AuthInfoServiceFake) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
return a.ExpectedExternalUser, a.ExpectedError
}
func (a *AuthInfoServiceFake) DeleteUserAuthInfo(ctx context.Context, userID int64) error {

View File

@ -91,8 +91,6 @@ type UpsertUserCommand struct {
ExternalUser *ExternalUserInfo
UserLookupParams
SignupAllowed bool
Result *user.User
}
type SetAuthInfoCommand struct {
@ -141,16 +139,12 @@ type UserLookupParams struct {
type GetExternalUserInfoByLoginQuery struct {
LoginOrEmail string
Result *ExternalUserInfo
}
type GetAuthInfoQuery struct {
UserId int64
AuthModule string
AuthId string
Result *UserAuth
}
type GetUserLabelsQuery struct {

View File

@ -59,7 +59,8 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
@ -69,10 +70,10 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
return nil
}
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfo)
if err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfoQuery.Result)
return buildOAuthTokenFromAuthInfo(authInfo)
}
return nil
@ -94,7 +95,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*l
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
@ -103,10 +104,10 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*l
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
return nil, false, err
}
if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") {
if !strings.Contains(authInfo.AuthModule, "oauth") {
return nil, false, nil
}
return authInfoQuery.Result, true, nil
return authInfo, true, nil
}
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful

View File

@ -119,12 +119,11 @@ func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &login.GetAuthInfoQuery{}
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
assert.Nil(t, err)
// User's token data had not been updated
resultUsr := authInfoQuery.Result
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
@ -193,15 +192,15 @@ func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &login.GetAuthInfoQuery{}
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
assert.Nil(t, err)
// newToken should be returned after the .Token() call, therefore the User had to be updated
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken)
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry)
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken)
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType)
assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken)
assert.Equal(t, authInfo.OAuthExpiry, newToken.Expiry)
assert.Equal(t, authInfo.OAuthRefreshToken, newToken.RefreshToken)
assert.Equal(t, authInfo.OAuthTokenType, newToken.TokenType)
}
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
@ -318,13 +317,12 @@ type FakeAuthInfoStore struct {
ExpectedLoginStats login.LoginStats
}
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) error {
return f.ExpectedError
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *login.GetExternalUserInfoByLoginQuery) (*login.ExternalUserInfo, error) {
return nil, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) error {
query.Result = f.ExpectedOAuth
return f.ExpectedError
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
return f.ExpectedOAuth, f.ExpectedError
}
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {