[main] Login email before username (#57400)

* Swap order of login fields

* Validate email field before validating the username field.

Co-authored-by: linoman <2051016+linoman@users.noreply.github.com>
This commit is contained in:
Karl Persson 2022-10-21 15:21:21 +02:00 committed by GitHub
parent 7e631e7239
commit d55e5b8862
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 21 deletions

View File

@ -170,27 +170,30 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu
return user.ErrUserNotFound
}
// Try and find the user by login first.
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
where := "login=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}
has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
if err != nil {
return err
}
if !has && strings.Contains(query.LoginOrEmail, "@") {
// If the user wasn't found, and it contains an "@" fallback to finding the
// user by email.
var where string
var has bool
var err error
// Since username can be an email address, attempt login with email address
// first if the login field has the "@" symbol.
if strings.Contains(query.LoginOrEmail, "@") {
where = "email=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
usr = &user.User{}
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
if err != nil {
return err
}
}
// Look for the login field instead of email
if !has {
where = "login=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
}
@ -199,7 +202,6 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu
} else if !has {
return user.ErrUserNotFound
}
if ss.cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
@ -207,10 +209,8 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu
}
return nil
})
if err != nil {
return nil, err
}
return usr, nil
return usr, err
}
func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {

View File

@ -202,6 +202,47 @@ func TestIntegrationUserDataAccess(t *testing.T) {
require.Error(t, err)
})
t.Run("GetByLogin - user2 uses user1.email as login", func(t *testing.T) {
// create user_1
user1 := &user.User{
Email: "user_1@mail.com",
Name: "user_1",
Login: "user_1",
Password: "user_1_password",
Created: time.Now(),
Updated: time.Now(),
IsDisabled: true,
}
_, err := userStore.Insert(context.Background(), user1)
require.Nil(t, err)
// create user_2
user2 := &user.User{
Email: "user_2@mail.com",
Name: "user_2",
Login: "user_1@mail.com",
Password: "user_2_password",
Created: time.Now(),
Updated: time.Now(),
IsDisabled: true,
}
_, err = userStore.Insert(context.Background(), user2)
require.Nil(t, err)
// query user database for user_1 email
query := user.GetUserByLoginQuery{LoginOrEmail: "user_1@mail.com"}
result, err := userStore.GetByLogin(context.Background(), &query)
require.Nil(t, err)
// expect user_1 as result
require.Equal(t, user1.Email, result.Email)
require.Equal(t, user1.Login, result.Login)
require.Equal(t, user1.Name, result.Name)
require.NotEqual(t, user2.Email, result.Email)
require.NotEqual(t, user2.Login, result.Login)
require.NotEqual(t, user2.Name, result.Name)
})
ss.Cfg.CaseInsensitiveLogin = false
})