Service Accounts: Refactor login service to use sqlstore methods (#46585)

* refactor login service to use sqlstore methods

* trailing newline
This commit is contained in:
Jguer
2022-03-15 15:57:21 +00:00
committed by GitHub
parent 11aa6a3e8f
commit 04267a66ec
4 changed files with 51 additions and 56 deletions

View File

@@ -214,33 +214,31 @@ func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAu
}) })
} }
func (s *AuthInfoStore) GetUserById(id int64) (bool, *models.User, error) { func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*models.User, error) {
var ( query := models.GetUserByIdQuery{Id: id}
has bool if err := s.sqlStore.GetUserById(ctx, &query); err != nil {
err error return nil, err
)
user := &models.User{}
err = s.sqlStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
has, err = sess.ID(id).Get(user)
return err
})
if err != nil {
return false, nil, err
} }
return has, user, nil return query.Result, nil
} }
func (s *AuthInfoStore) GetUser(user *models.User) (bool, error) { func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
var err error query := models.GetUserByLoginQuery{LoginOrEmail: login}
var has bool if err := s.sqlStore.GetUserByLogin(ctx, &query); err != nil {
return nil, err
}
err = s.sqlStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { return query.Result, nil
has, err = sess.Get(user) }
return err
})
return has, err func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := models.GetUserByEmailQuery{Email: email}
if err := s.sqlStore.GetUserByEmail(ctx, &query); err != nil {
return nil, err
}
return query.Result, nil
} }
// decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it // decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it

View File

@@ -52,21 +52,20 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser
return false, nil, nil, models.ErrUserNotFound return false, nil, nil, models.ErrUserNotFound
} else { } else {
has, user, err := s.authInfoStore.GetUserById(authQuery.Result.UserId) user, err := s.authInfoStore.GetUserById(ctx, authQuery.Result.UserId)
if err != nil { if err != nil {
return false, nil, nil, err if errors.Is(err, models.ErrUserNotFound) {
} // if the user has been deleted then remove the entry
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
}); errDel != nil {
s.logger.Error("Error removing user_auth entry", "error", errDel)
}
if !has { return false, nil, nil, models.ErrUserNotFound
// if the user has been deleted then remove the entry
err = s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{
UserAuth: authQuery.Result,
})
if err != nil {
s.logger.Error("Error removing user_auth entry", "error", err)
} }
return false, nil, nil, models.ErrUserNotFound return false, nil, nil, err
} }
return true, user, authQuery.Result, nil return true, user, authQuery.Result, nil
@@ -77,42 +76,39 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser
return false, nil, nil, models.ErrUserNotFound return false, nil, nil, models.ErrUserNotFound
} }
func (s *Implementation) LookupByOneOf(userId int64, email string, login string) (bool, *models.User, error) { func (s *Implementation) LookupByOneOf(ctx context.Context, userId int64, email string, login string) (*models.User, error) {
foundUser := false
var user *models.User var user *models.User
var err error var err error
// If not found, try to find the user by id // If not found, try to find the user by id
if userId != 0 { if userId != 0 {
foundUser, user, err = s.authInfoStore.GetUserById(userId) user, err = s.authInfoStore.GetUserById(ctx, userId)
if err != nil { if err != nil && !errors.Is(err, models.ErrUserNotFound) {
return false, nil, err return nil, err
} }
} }
// If not found, try to find the user by email address // If not found, try to find the user by email address
if !foundUser && email != "" { if user == nil && email != "" {
user = &models.User{Email: email} user, err = s.authInfoStore.GetUserByEmail(ctx, email)
foundUser, err = s.authInfoStore.GetUser(user) if err != nil && !errors.Is(err, models.ErrUserNotFound) {
if err != nil { return nil, err
return false, nil, err
} }
} }
// If not found, try to find the user by login // If not found, try to find the user by login
if !foundUser && login != "" { if user == nil && login != "" {
user = &models.User{Login: login} user, err = s.authInfoStore.GetUserByLogin(ctx, login)
foundUser, err = s.authInfoStore.GetUser(user) if err != nil && !errors.Is(err, models.ErrUserNotFound) {
if err != nil { return nil, err
return false, nil, err
} }
} }
if !foundUser { if user == nil {
return false, nil, models.ErrUserNotFound return nil, models.ErrUserNotFound
} }
return foundUser, user, nil return user, nil
} }
func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) { func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) {
@@ -141,7 +137,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetU
// 2. FindByUserDetails // 2. FindByUserDetails
if !foundUser { if !foundUser {
_, user, err = s.LookupByOneOf(query.UserId, query.Email, query.Login) user, err = s.LookupByOneOf(ctx, query.UserId, query.Email, query.Login)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -47,7 +47,7 @@ func TestUserAuth(t *testing.T) {
// By ID // By ID
id := user.Id id := user.Id
_, user, err = srv.LookupByOneOf(id, "", "") user, err = srv.LookupByOneOf(context.Background(), id, "", "")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Id, id) require.Equal(t, user.Id, id)
@@ -55,7 +55,7 @@ func TestUserAuth(t *testing.T) {
// By Email // By Email
email := "user1@test.com" email := "user1@test.com"
_, user, err = srv.LookupByOneOf(0, email, "") user, err = srv.LookupByOneOf(context.Background(), 0, email, "")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Email, email) require.Equal(t, user.Email, email)
@@ -63,7 +63,7 @@ func TestUserAuth(t *testing.T) {
// Don't find nonexistent user // Don't find nonexistent user
email = "nonexistent@test.com" email = "nonexistent@test.com"
_, user, err = srv.LookupByOneOf(0, email, "") user, err = srv.LookupByOneOf(context.Background(), 0, email, "")
require.Equal(t, models.ErrUserNotFound, err) require.Equal(t, models.ErrUserNotFound, err)
require.Nil(t, user) require.Nil(t, user)

View File

@@ -16,6 +16,7 @@ type Store interface {
SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error
UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error
DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error
GetUserById(id int64) (bool, *models.User, error) GetUserById(ctx context.Context, id int64) (*models.User, error)
GetUser(user *models.User) (bool, error) GetUserByLogin(ctx context.Context, login string) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
} }