From 72db9739c7597cbdfa203e3ece3c65f2099554f7 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Mon, 20 Nov 2023 14:58:32 +0100 Subject: [PATCH] UserAuth: clean-up auth entries on update (#78377) * UserAuth: clean-up auth entries on update * Update pkg/services/login/authinfoservice/database/database.go Co-authored-by: Gabriel MABILLE --------- Co-authored-by: Gabriel MABILLE --- .../authinfoservice/database/database.go | 29 +++++++- .../authinfoservice/database/database_test.go | 68 +++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 pkg/services/login/authinfoservice/database/database_test.go diff --git a/pkg/services/login/authinfoservice/database/database.go b/pkg/services/login/authinfoservice/database/database.go index e3c384f889e..992d74bd864 100644 --- a/pkg/services/login/authinfoservice/database/database.go +++ b/pkg/services/login/authinfoservice/database/database.go @@ -165,6 +165,7 @@ func (s *AuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login. } return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error { _, err := sess.Cols("created").Update(authInfo, cond) + return err }) } @@ -208,8 +209,32 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAut return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error { upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser) - s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, - "auth_id", cmd.AuthId, "auth_module", cmd.AuthModule, "rows", upd) + + s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_id", cmd.AuthId, "auth_module", cmd.AuthModule, "rows", upd) + + // Clean up duplicated entries + if upd > 1 { + var id int64 + ok, err := sess.SQL( + "SELECT id FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ?", + cmd.UserId, cmd.AuthModule, cmd.AuthId, + ).Get(&id) + + if err != nil { + return err + } + + if !ok { + return nil + } + + _, err = sess.Exec( + "DELETE FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ? AND id != ?", + cmd.UserId, cmd.AuthModule, cmd.AuthId, id, + ) + return err + } + return err }) } diff --git a/pkg/services/login/authinfoservice/database/database_test.go b/pkg/services/login/authinfoservice/database/database_test.go new file mode 100644 index 00000000000..98f66fe3a9a --- /dev/null +++ b/pkg/services/login/authinfoservice/database/database_test.go @@ -0,0 +1,68 @@ +package database + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/services/login" + secretstest "github.com/grafana/grafana/pkg/services/secrets/fakes" +) + +func TestIntegrationAuthInfoStore(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + sql := db.InitTestDB(t) + store := ProvideAuthInfoStore(sql, secretstest.NewFakeSecretsService(), nil) + + t.Run("should remove duplicates on update", func(t *testing.T) { + ctx := context.Background() + setCmd := &login.SetAuthInfoCommand{ + AuthModule: login.GenericOAuthModule, + AuthId: "1", + UserId: 1, + } + + require.NoError(t, store.SetAuthInfo(ctx, setCmd)) + require.NoError(t, store.SetAuthInfo(ctx, setCmd)) + + count := countEntries(t, sql, setCmd.AuthModule, setCmd.AuthId, setCmd.UserId) + require.Equal(t, 2, count) + + err := store.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{ + AuthModule: setCmd.AuthModule, + AuthId: setCmd.AuthId, + UserId: setCmd.UserId, + OAuthToken: &oauth2.Token{ + AccessToken: "atoken", + RefreshToken: "rtoken", + Expiry: time.Now(), + }, + }) + require.NoError(t, err) + + count = countEntries(t, sql, setCmd.AuthModule, setCmd.AuthId, setCmd.UserId) + require.Equal(t, 1, count) + }) +} + +func countEntries(t *testing.T, sql db.DB, authModule, authID string, userID int64) int { + var result int + + err := sql.WithDbSession(context.Background(), func(sess *db.Session) error { + _, err := sess.SQL( + "SELECT COUNT(*) FROM user_auth WHERE auth_module = ? AND auth_id = ? AND user_id = ?", + authModule, authID, userID, + ).Get(&result) + return err + }) + + require.NoError(t, err) + return result +}