From d55e5b886223f5abdc6197c94be5fb5c57987a83 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Fri, 21 Oct 2022 15:21:21 +0200 Subject: [PATCH] [main] Login email before username (#57400) * Swap order of login fields * Validate email field before validating the username field. Co-authored-by: linoman <2051016+linoman@users.noreply.github.com> --- pkg/services/user/userimpl/store.go | 42 ++++++++++++------------ pkg/services/user/userimpl/store_test.go | 41 +++++++++++++++++++++++ 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/pkg/services/user/userimpl/store.go b/pkg/services/user/userimpl/store.go index 05680365edb..a27850bfb4d 100644 --- a/pkg/services/user/userimpl/store.go +++ b/pkg/services/user/userimpl/store.go @@ -170,27 +170,30 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu return user.ErrUserNotFound } - // Try and find the user by login first. - // It's not sufficient to assume that a LoginOrEmail with an "@" is an email. - where := "login=?" - if ss.cfg.CaseInsensitiveLogin { - where = "LOWER(login)=LOWER(?)" - } - - has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr) - if err != nil { - return err - } - - if !has && strings.Contains(query.LoginOrEmail, "@") { - // If the user wasn't found, and it contains an "@" fallback to finding the - // user by email. + var where string + var has bool + var err error + // Since username can be an email address, attempt login with email address + // first if the login field has the "@" symbol. + if strings.Contains(query.LoginOrEmail, "@") { where = "email=?" if ss.cfg.CaseInsensitiveLogin { where = "LOWER(email)=LOWER(?)" } - usr = &user.User{} + has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr) + + if err != nil { + return err + } + } + + // Look for the login field instead of email + if !has { + where = "login=?" + if ss.cfg.CaseInsensitiveLogin { + where = "LOWER(login)=LOWER(?)" + } has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr) } @@ -199,7 +202,6 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu } else if !has { return user.ErrUserNotFound } - if ss.cfg.CaseInsensitiveLogin { if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil { return err @@ -207,10 +209,8 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu } return nil }) - if err != nil { - return nil, err - } - return usr, nil + + return usr, err } func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) { diff --git a/pkg/services/user/userimpl/store_test.go b/pkg/services/user/userimpl/store_test.go index 3fb443f8190..9f108f16926 100644 --- a/pkg/services/user/userimpl/store_test.go +++ b/pkg/services/user/userimpl/store_test.go @@ -202,6 +202,47 @@ func TestIntegrationUserDataAccess(t *testing.T) { require.Error(t, err) }) + t.Run("GetByLogin - user2 uses user1.email as login", func(t *testing.T) { + // create user_1 + user1 := &user.User{ + Email: "user_1@mail.com", + Name: "user_1", + Login: "user_1", + Password: "user_1_password", + Created: time.Now(), + Updated: time.Now(), + IsDisabled: true, + } + _, err := userStore.Insert(context.Background(), user1) + require.Nil(t, err) + + // create user_2 + user2 := &user.User{ + Email: "user_2@mail.com", + Name: "user_2", + Login: "user_1@mail.com", + Password: "user_2_password", + Created: time.Now(), + Updated: time.Now(), + IsDisabled: true, + } + _, err = userStore.Insert(context.Background(), user2) + require.Nil(t, err) + + // query user database for user_1 email + query := user.GetUserByLoginQuery{LoginOrEmail: "user_1@mail.com"} + result, err := userStore.GetByLogin(context.Background(), &query) + require.Nil(t, err) + + // expect user_1 as result + require.Equal(t, user1.Email, result.Email) + require.Equal(t, user1.Login, result.Login) + require.Equal(t, user1.Name, result.Name) + require.NotEqual(t, user2.Email, result.Email) + require.NotEqual(t, user2.Login, result.Login) + require.NotEqual(t, user2.Name, result.Name) + }) + ss.Cfg.CaseInsensitiveLogin = false })