Generic OAuth: Prevent adding duplicated users (#32286)

* Add special check for generic oauth case

* Converted from Convey to testify

* Fix according to reviewer's comments

* More changes according to reviewer's comments

* Handle error if user is not found

* Move generic oauth test from user_test.go to user_auth_test.go

* Update pkg/services/sqlstore/user_auth_test.go

Co-authored-by: Marcus Efraimsson <marcus.efraimsson@gmail.com>

* Created genericOAuthModule const

Co-authored-by: Marcus Efraimsson <marcus.efraimsson@gmail.com>
This commit is contained in:
Dimitris Sotirakis 2021-04-09 14:28:35 +03:00 committed by GitHub
parent e6a98ce1e4
commit b867ceda9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 715 additions and 631 deletions

View File

@ -13,6 +13,8 @@ import (
var getTime = time.Now
const genericOAuthModule = "oauth_generic_oauth"
func init() {
bus.AddHandler("sql", GetUserByAuthInfo)
bus.AddHandler("sql", GetExternalUserInfoByLogin)
@ -101,7 +103,17 @@ func GetUserByAuthInfo(query *models.GetUserByAuthInfoQuery) error {
return models.ErrUserNotFound
}
// create authInfo record to link accounts
// Special case for generic oauth duplicates
if query.AuthModule == genericOAuthModule && user.Id != 0 {
authQuery.UserId = user.Id
authQuery.AuthModule = query.AuthModule
err = GetAuthInfo(authQuery)
if !errors.Is(err, models.ErrUserNotFound) {
if err != nil {
return err
}
}
}
if authQuery.Result == nil && query.AuthModule != "" {
cmd2 := &models.SetAuthInfoCommand{
UserId: user.Id,
@ -151,6 +163,7 @@ func GetAuthInfo(query *models.GetAuthInfoQuery) error {
if err != nil {
return err
}
if !has {
return models.ErrUserNotFound
}

View File

@ -5,11 +5,11 @@ package sqlstore
import (
"context"
"fmt"
"github.com/stretchr/testify/require"
"testing"
"time"
"github.com/grafana/grafana/pkg/models"
. "github.com/smartystreets/goconvey/convey"
"golang.org/x/oauth2"
)
@ -17,7 +17,7 @@ import (
func TestUserAuth(t *testing.T) {
sqlStore := InitTestDB(t)
Convey("Given 5 users", t, func() {
t.Run("Given 5 users", func(t *testing.T) {
for i := 0; i < 5; i++ {
cmd := models.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
@ -25,29 +25,18 @@ func TestUserAuth(t *testing.T) {
Login: fmt.Sprint("loginuser", i),
}
_, err := sqlStore.CreateUser(context.Background(), cmd)
So(err, ShouldBeNil)
require.Nil(t, err)
}
Reset(func() {
_, err := x.Exec("DELETE FROM org_user WHERE 1=1")
So(err, ShouldBeNil)
_, err = x.Exec("DELETE FROM org WHERE 1=1")
So(err, ShouldBeNil)
_, err = x.Exec("DELETE FROM " + dialect.Quote("user") + " WHERE 1=1")
So(err, ShouldBeNil)
_, err = x.Exec("DELETE FROM user_auth WHERE 1=1")
So(err, ShouldBeNil)
})
Convey("Can find existing user", func() {
t.Run("Can find existing user", func(t *testing.T) {
// By Login
login := "loginuser0"
query := &models.GetUserByAuthInfoQuery{Login: login}
err := GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// By ID
id := query.Result.Id
@ -55,8 +44,8 @@ func TestUserAuth(t *testing.T) {
query = &models.GetUserByAuthInfoQuery{UserId: id}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Id, ShouldEqual, id)
require.Nil(t, err)
require.Equal(t, query.Result.Id, id)
// By Email
email := "user1@test.com"
@ -64,8 +53,8 @@ func TestUserAuth(t *testing.T) {
query = &models.GetUserByAuthInfoQuery{Email: email}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Email, ShouldEqual, email)
require.Nil(t, err)
require.Equal(t, query.Result.Email, email)
// Don't find nonexistent user
email = "nonexistent@test.com"
@ -73,17 +62,17 @@ func TestUserAuth(t *testing.T) {
query = &models.GetUserByAuthInfoQuery{Email: email}
err = GetUserByAuthInfo(query)
So(err, ShouldEqual, models.ErrUserNotFound)
So(query.Result, ShouldBeNil)
require.Equal(t, err, models.ErrUserNotFound)
require.Nil(t, query.Result)
})
Convey("Can set & locate by AuthModule and AuthId", func() {
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
// get nonexistent user_auth entry
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
err := GetUserByAuthInfo(query)
So(err, ShouldEqual, models.ErrUserNotFound)
So(query.Result, ShouldBeNil)
require.Equal(t, err, models.ErrUserNotFound)
require.Nil(t, query.Result)
// create user_auth entry
login := "loginuser0"
@ -91,15 +80,15 @@ func TestUserAuth(t *testing.T) {
query.Login = login
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// get via user_auth
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// get with non-matching id
id := query.Result.Id
@ -107,29 +96,29 @@ func TestUserAuth(t *testing.T) {
query.UserId = id + 1
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, "loginuser1")
require.Nil(t, err)
require.Equal(t, query.Result.Login, "loginuser1")
// get via user_auth
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, "loginuser1")
require.Nil(t, err)
require.Equal(t, query.Result.Login, "loginuser1")
// remove user
_, err = x.Exec("DELETE FROM "+dialect.Quote("user")+" WHERE id=?", query.Result.Id)
So(err, ShouldBeNil)
require.Nil(t, err)
// get via user_auth for deleted user
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
err = GetUserByAuthInfo(query)
So(err, ShouldEqual, models.ErrUserNotFound)
So(query.Result, ShouldBeNil)
require.Equal(t, err, models.ErrUserNotFound)
require.Nil(t, query.Result)
})
Convey("Can set & retrieve oauth token information", func() {
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
@ -144,8 +133,8 @@ func TestUserAuth(t *testing.T) {
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
err := GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
cmd := &models.UpdateAuthInfoCommand{
UserId: query.Result.Id,
@ -155,7 +144,7 @@ func TestUserAuth(t *testing.T) {
}
err = UpdateAuthInfo(cmd)
So(err, ShouldBeNil)
require.Nil(t, err)
getAuthQuery := &models.GetAuthInfoQuery{
UserId: query.Result.Id,
@ -163,13 +152,26 @@ func TestUserAuth(t *testing.T) {
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
require.Equal(t, getAuthQuery.Result.OAuthRefreshToken, token.RefreshToken)
require.Equal(t, getAuthQuery.Result.OAuthTokenType, token.TokenType)
})
Convey("Always return the most recently used auth_module", func() {
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
// Restore after destructive operation
sqlStore = InitTestDB(t)
for i := 0; i < 5; i++ {
cmd := models.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
}
_, err := sqlStore.CreateUser(context.Background(), cmd)
require.Nil(t, err)
}
// Find a user to set tokens on
login := "loginuser0"
@ -180,8 +182,8 @@ func TestUserAuth(t *testing.T) {
err := GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// Add a second auth module for this user
// Have this module's last log-in be more recent
@ -190,8 +192,8 @@ func TestUserAuth(t *testing.T) {
err = GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery := &models.GetAuthInfoQuery{
@ -200,14 +202,14 @@ func TestUserAuth(t *testing.T) {
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test2")
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test2")
// "log in" again with the first auth module
updateAuthCmd := &models.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
err = UpdateAuthInfo(updateAuthCmd)
So(err, ShouldBeNil)
require.Nil(t, err)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery = &models.GetAuthInfoQuery{
@ -216,8 +218,31 @@ func TestUserAuth(t *testing.T) {
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test1")
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.AuthModule, "test1")
})
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
// Find a user to set tokens on
login := "loginuser0"
// Expect to pass since there's a matching login user
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""}
err := GetUserByAuthInfo(query)
getTime = time.Now
require.Nil(t, err)
require.Equal(t, query.Result.Login, login)
// Should throw a "user not found" error since there's no matching login user
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""}
err = GetUserByAuthInfo(query)
getTime = time.Now
require.NotNil(t, err)
require.Nil(t, query.Result)
})
})
}

File diff suppressed because it is too large Load Diff