Service Accounts: Refactor login service to use sqlstore methods (#46585)

* refactor login service to use sqlstore methods

* trailing newline
This commit is contained in:
Jguer 2022-03-15 15:57:21 +00:00 committed by GitHub
parent 11aa6a3e8f
commit 04267a66ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 56 deletions

View File

@ -214,33 +214,31 @@ func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAu
})
}
func (s *AuthInfoStore) GetUserById(id int64) (bool, *models.User, error) {
var (
has bool
err error
)
user := &models.User{}
err = s.sqlStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
has, err = sess.ID(id).Get(user)
return err
})
if err != nil {
return false, nil, err
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*models.User, error) {
query := models.GetUserByIdQuery{Id: id}
if err := s.sqlStore.GetUserById(ctx, &query); err != nil {
return nil, err
}
return has, user, nil
return query.Result, nil
}
func (s *AuthInfoStore) GetUser(user *models.User) (bool, error) {
var err error
var has bool
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
query := models.GetUserByLoginQuery{LoginOrEmail: login}
if err := s.sqlStore.GetUserByLogin(ctx, &query); err != nil {
return nil, err
}
err = s.sqlStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
has, err = sess.Get(user)
return err
})
return query.Result, nil
}
return has, err
func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := models.GetUserByEmailQuery{Email: email}
if err := s.sqlStore.GetUserByEmail(ctx, &query); err != nil {
return nil, err
}
return query.Result, nil
}
// decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it

View File

@ -52,21 +52,20 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser
return false, nil, nil, models.ErrUserNotFound
} else {
has, user, err := s.authInfoStore.GetUserById(authQuery.Result.UserId)
user, err := s.authInfoStore.GetUserById(ctx, authQuery.Result.UserId)
if err != nil {
return false, nil, nil, err
}
if errors.Is(err, models.ErrUserNotFound) {
// if the user has been deleted then remove the entry
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
}); errDel != nil {
s.logger.Error("Error removing user_auth entry", "error", errDel)
}
if !has {
// if the user has been deleted then remove the entry
err = s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
})
if err != nil {
s.logger.Error("Error removing user_auth entry", "error", err)
return false, nil, nil, models.ErrUserNotFound
}
return false, nil, nil, models.ErrUserNotFound
return false, nil, nil, err
}
return true, user, authQuery.Result, nil
@ -77,42 +76,39 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser
return false, nil, nil, models.ErrUserNotFound
}
func (s *Implementation) LookupByOneOf(userId int64, email string, login string) (bool, *models.User, error) {
foundUser := false
func (s *Implementation) LookupByOneOf(ctx context.Context, userId int64, email string, login string) (*models.User, error) {
var user *models.User
var err error
// If not found, try to find the user by id
if userId != 0 {
foundUser, user, err = s.authInfoStore.GetUserById(userId)
if err != nil {
return false, nil, err
user, err = s.authInfoStore.GetUserById(ctx, userId)
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
return nil, err
}
}
// If not found, try to find the user by email address
if !foundUser && email != "" {
user = &models.User{Email: email}
foundUser, err = s.authInfoStore.GetUser(user)
if err != nil {
return false, nil, err
if user == nil && email != "" {
user, err = s.authInfoStore.GetUserByEmail(ctx, email)
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
return nil, err
}
}
// If not found, try to find the user by login
if !foundUser && login != "" {
user = &models.User{Login: login}
foundUser, err = s.authInfoStore.GetUser(user)
if err != nil {
return false, nil, err
if user == nil && login != "" {
user, err = s.authInfoStore.GetUserByLogin(ctx, login)
if err != nil && !errors.Is(err, models.ErrUserNotFound) {
return nil, err
}
}
if !foundUser {
return false, nil, models.ErrUserNotFound
if user == nil {
return nil, models.ErrUserNotFound
}
return foundUser, user, nil
return user, nil
}
func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) {
@ -141,7 +137,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetU
// 2. FindByUserDetails
if !foundUser {
_, user, err = s.LookupByOneOf(query.UserId, query.Email, query.Login)
user, err = s.LookupByOneOf(ctx, query.UserId, query.Email, query.Login)
if err != nil {
return nil, err
}

View File

@ -47,7 +47,7 @@ func TestUserAuth(t *testing.T) {
// By ID
id := user.Id
_, user, err = srv.LookupByOneOf(id, "", "")
user, err = srv.LookupByOneOf(context.Background(), id, "", "")
require.Nil(t, err)
require.Equal(t, user.Id, id)
@ -55,7 +55,7 @@ func TestUserAuth(t *testing.T) {
// By Email
email := "user1@test.com"
_, user, err = srv.LookupByOneOf(0, email, "")
user, err = srv.LookupByOneOf(context.Background(), 0, email, "")
require.Nil(t, err)
require.Equal(t, user.Email, email)
@ -63,7 +63,7 @@ func TestUserAuth(t *testing.T) {
// Don't find nonexistent user
email = "nonexistent@test.com"
_, user, err = srv.LookupByOneOf(0, email, "")
user, err = srv.LookupByOneOf(context.Background(), 0, email, "")
require.Equal(t, models.ErrUserNotFound, err)
require.Nil(t, user)

View File

@ -16,6 +16,7 @@ type Store interface {
SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error
UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error
DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error
GetUserById(id int64) (bool, *models.User, error)
GetUser(user *models.User) (bool, error)
GetUserById(ctx context.Context, id int64) (*models.User, error)
GetUserByLogin(ctx context.Context, login string) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
}