diff --git a/pkg/cmd/grafana-cli/commands/datamigrations/encrypt_datasource_passwords_test.go b/pkg/cmd/grafana-cli/commands/datamigrations/encrypt_datasource_passwords_test.go index 98b99de6424..114f5291111 100644 --- a/pkg/cmd/grafana-cli/commands/datamigrations/encrypt_datasource_passwords_test.go +++ b/pkg/cmd/grafana-cli/commands/datamigrations/encrypt_datasource_passwords_test.go @@ -17,10 +17,15 @@ import ( func TestPasswordMigrationCommand(t *testing.T) { // setup datasources with password, basic_auth and none - sqlstore := sqlstore.InitTestDB(t) - session := sqlstore.NewSession(context.Background()) - defer session.Close() + store := sqlstore.InitTestDB(t) + err := store.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + passwordMigration(t, sess, store) + return nil + }) + require.NoError(t, err) +} +func passwordMigration(t *testing.T, session *sqlstore.DBSession, sqlstore *sqlstore.SQLStore) { ds := []*datasources.DataSource{ {Type: "influxdb", Name: "influxdb", Password: "foobar", Uid: "influx"}, {Type: "graphite", Name: "graphite", BasicAuthPassword: "foobar", Uid: "graphite"}, diff --git a/pkg/infra/remotecache/database_storage.go b/pkg/infra/remotecache/database_storage.go index d9ff8702b01..ddba108693a 100644 --- a/pkg/infra/remotecache/database_storage.go +++ b/pkg/infra/remotecache/database_storage.go @@ -54,36 +54,38 @@ func (dc *databaseCache) internalRunGC() { func (dc *databaseCache) Get(ctx context.Context, key string) (interface{}, error) { cacheHit := CacheData{} - session := dc.SQLStore.NewSession(ctx) - defer session.Close() - - exist, err := session.Where("cache_key= ?", key).Get(&cacheHit) - - if err != nil { - return nil, err - } - - if !exist { - return nil, ErrCacheItemNotFound - } - - if cacheHit.Expires > 0 { - existedButExpired := getTime().Unix()-cacheHit.CreatedAt >= cacheHit.Expires - if existedButExpired { - err = dc.Delete(ctx, key) // ignore this error since we will return `ErrCacheItemNotFound` anyway - if err != nil { - dc.log.Debug("Deletion of expired key failed: %v", err) - } - return nil, ErrCacheItemNotFound - } - } item := &cachedItem{} - if err = decodeGob(cacheHit.Data, item); err != nil { - return nil, err - } + err := dc.SQLStore.WithDbSession(ctx, func(session *sqlstore.DBSession) error { + exist, err := session.Where("cache_key= ?", key).Get(&cacheHit) - return item.Val, nil + if err != nil { + return err + } + + if !exist { + return ErrCacheItemNotFound + } + + if cacheHit.Expires > 0 { + existedButExpired := getTime().Unix()-cacheHit.CreatedAt >= cacheHit.Expires + if existedButExpired { + err = dc.Delete(ctx, key) // ignore this error since we will return `ErrCacheItemNotFound` anyway + if err != nil { + dc.log.Debug("Deletion of expired key failed: %v", err) + } + return ErrCacheItemNotFound + } + } + + if err = decodeGob(cacheHit.Data, item); err != nil { + return err + } + + return nil + }) + + return item.Val, err } func (dc *databaseCache) Set(ctx context.Context, key string, value interface{}, expire time.Duration) error { @@ -93,34 +95,33 @@ func (dc *databaseCache) Set(ctx context.Context, key string, value interface{}, return err } - session := dc.SQLStore.NewSession(context.Background()) - defer session.Close() + return dc.SQLStore.WithDbSession(ctx, func(session *sqlstore.DBSession) error { + var expiresInSeconds int64 + if expire != 0 { + expiresInSeconds = int64(expire) / int64(time.Second) + } - var expiresInSeconds int64 - if expire != 0 { - expiresInSeconds = int64(expire) / int64(time.Second) - } - - // attempt to insert the key - sql := `INSERT INTO cache_data (cache_key,data,created_at,expires) VALUES(?,?,?,?)` - _, err = session.Exec(sql, key, data, getTime().Unix(), expiresInSeconds) - if err != nil { - // attempt to update if a unique constrain violation or a deadlock (for MySQL) occurs - // if the update fails propagate the error - // which eventually will result in a key that is not finally set - // but since it's a cache does not harm a lot - if dc.SQLStore.Dialect.IsUniqueConstraintViolation(err) || dc.SQLStore.Dialect.IsDeadlock(err) { - sql := `UPDATE cache_data SET data=?, created_at=?, expires=? WHERE cache_key=?` - _, err = session.Exec(sql, data, getTime().Unix(), expiresInSeconds, key) - if err != nil && dc.SQLStore.Dialect.IsDeadlock(err) { - // most probably somebody else is upserting the key - // so it is safe enough not to propagate this error - return nil + // attempt to insert the key + sql := `INSERT INTO cache_data (cache_key,data,created_at,expires) VALUES(?,?,?,?)` + _, err := session.Exec(sql, key, data, getTime().Unix(), expiresInSeconds) + if err != nil { + // attempt to update if a unique constrain violation or a deadlock (for MySQL) occurs + // if the update fails propagate the error + // which eventually will result in a key that is not finally set + // but since it's a cache does not harm a lot + if dc.SQLStore.Dialect.IsUniqueConstraintViolation(err) || dc.SQLStore.Dialect.IsDeadlock(err) { + sql := `UPDATE cache_data SET data=?, created_at=?, expires=? WHERE cache_key=?` + _, err = session.Exec(sql, data, getTime().Unix(), expiresInSeconds, key) + if err != nil && dc.SQLStore.Dialect.IsDeadlock(err) { + // most probably somebody else is upserting the key + // so it is safe enough not to propagate this error + return nil + } } } - } - return err + return err + }) } func (dc *databaseCache) Delete(ctx context.Context, key string) error { diff --git a/pkg/services/accesscontrol/filter_bench_test.go b/pkg/services/accesscontrol/filter_bench_test.go index 7cd5b788d63..1e4e930c07e 100644 --- a/pkg/services/accesscontrol/filter_bench_test.go +++ b/pkg/services/accesscontrol/filter_bench_test.go @@ -43,10 +43,10 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) { require.NoError(b, err) var datasources []datasources.DataSource - sess := store.NewSession(context.Background()) - err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) + err = store.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + return sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) + }) require.NoError(b, err) - sess.Close() require.Len(b, datasources, numPermissions) } } diff --git a/pkg/services/accesscontrol/filter_test.go b/pkg/services/accesscontrol/filter_test.go index cb343a8ae5d..a7bb1d059b1 100644 --- a/pkg/services/accesscontrol/filter_test.go +++ b/pkg/services/accesscontrol/filter_test.go @@ -168,40 +168,41 @@ func TestFilter_Datasources(t *testing.T) { t.Run(tt.desc, func(t *testing.T) { store := sqlstore.InitTestDB(t) - sess := store.NewSession(context.Background()) - defer sess.Close() - - // seed 10 data sources - for i := 1; i <= 10; i++ { - dsStore := dsService.CreateStore(store, log.New("accesscontrol.test")) - err := dsStore.AddDataSource(context.Background(), &datasources.AddDataSourceCommand{Name: fmt.Sprintf("ds:%d", i), Uid: fmt.Sprintf("uid%d", i)}) - require.NoError(t, err) - } - - baseSql := `SELECT data_source.* FROM data_source WHERE` - acFilter, err := accesscontrol.Filter( - &user.SignedInUser{ - OrgID: 1, - Permissions: map[int64]map[string][]string{1: tt.permissions}, - }, - tt.sqlID, - tt.prefix, - tt.actions..., - ) - - if !tt.expectErr { - require.NoError(t, err) - var datasources []datasources.DataSource - err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) - require.NoError(t, err) - - assert.Len(t, datasources, len(tt.expectedDataSources)) - for i, ds := range datasources { - assert.Equal(t, tt.expectedDataSources[i], ds.Name) + err := store.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + // seed 10 data sources + for i := 1; i <= 10; i++ { + dsStore := dsService.CreateStore(store, log.New("accesscontrol.test")) + err := dsStore.AddDataSource(context.Background(), &datasources.AddDataSourceCommand{Name: fmt.Sprintf("ds:%d", i), Uid: fmt.Sprintf("uid%d", i)}) + require.NoError(t, err) } - } else { - require.Error(t, err) - } + + baseSql := `SELECT data_source.* FROM data_source WHERE` + acFilter, err := accesscontrol.Filter( + &user.SignedInUser{ + OrgID: 1, + Permissions: map[int64]map[string][]string{1: tt.permissions}, + }, + tt.sqlID, + tt.prefix, + tt.actions..., + ) + + if !tt.expectErr { + require.NoError(t, err) + var datasources []datasources.DataSource + err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) + require.NoError(t, err) + + assert.Len(t, datasources, len(tt.expectedDataSources)) + for i, ds := range datasources { + assert.Equal(t, tt.expectedDataSources[i], ds.Name) + } + } else { + require.Error(t, err) + } + return nil + }) + require.NoError(t, err) }) } } diff --git a/pkg/services/alerting/store_notification.go b/pkg/services/alerting/store_notification.go index 9423108896a..9d9cca60db1 100644 --- a/pkg/services/alerting/store_notification.go +++ b/pkg/services/alerting/store_notification.go @@ -58,7 +58,9 @@ func (ss *sqlStore) DeleteAlertNotification(ctx context.Context, cmd *models.Del func (ss *sqlStore) DeleteAlertNotificationWithUid(ctx context.Context, cmd *models.DeleteAlertNotificationWithUidCommand) error { existingNotification := &models.GetAlertNotificationsWithUidQuery{OrgId: cmd.OrgId, Uid: cmd.Uid} - if err := getAlertNotificationWithUidInternal(ctx, existingNotification, ss.db.NewSession(ctx)); err != nil { + if err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return getAlertNotificationWithUidInternal(ctx, existingNotification, sess) + }); err != nil { return err } @@ -79,7 +81,9 @@ func (ss *sqlStore) DeleteAlertNotificationWithUid(ctx context.Context, cmd *mod } func (ss *sqlStore) GetAlertNotifications(ctx context.Context, query *models.GetAlertNotificationsQuery) error { - return getAlertNotificationInternal(ctx, query, ss.db.NewSession(ctx)) + return ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return getAlertNotificationInternal(ctx, query, sess) + }) } func (ss *sqlStore) GetAlertNotificationUidWithId(ctx context.Context, query *models.GetAlertNotificationUidQuery) error { @@ -90,8 +94,9 @@ func (ss *sqlStore) GetAlertNotificationUidWithId(ctx context.Context, query *mo return nil } - err := getAlertNotificationUidInternal(ctx, query, ss.db.NewSession(ctx)) - if err != nil { + if err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return getAlertNotificationUidInternal(ctx, query, sess) + }); err != nil { return err } @@ -105,7 +110,9 @@ func newAlertNotificationUidCacheKey(orgID, notificationId int64) string { } func (ss *sqlStore) GetAlertNotificationsWithUid(ctx context.Context, query *models.GetAlertNotificationsWithUidQuery) error { - return getAlertNotificationWithUidInternal(ctx, query, ss.db.NewSession(ctx)) + return ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return getAlertNotificationWithUidInternal(ctx, query, sess) + }) } func (ss *sqlStore) GetAllAlertNotifications(ctx context.Context, query *models.GetAllAlertNotificationsQuery) error { @@ -444,7 +451,9 @@ func (ss *sqlStore) UpdateAlertNotification(ctx context.Context, cmd *models.Upd func (ss *sqlStore) UpdateAlertNotificationWithUid(ctx context.Context, cmd *models.UpdateAlertNotificationWithUidCommand) error { getAlertNotificationWithUidQuery := &models.GetAlertNotificationsWithUidQuery{OrgId: cmd.OrgId, Uid: cmd.Uid} - if err := getAlertNotificationWithUidInternal(ctx, getAlertNotificationWithUidQuery, ss.db.NewSession(ctx)); err != nil { + if err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return getAlertNotificationWithUidInternal(ctx, getAlertNotificationWithUidQuery, sess) + }); err != nil { return err } diff --git a/pkg/services/annotations/annotationsimpl/cleanup_test.go b/pkg/services/annotations/annotationsimpl/cleanup_test.go index 844d7855157..baaaeac5a7c 100644 --- a/pkg/services/annotations/annotationsimpl/cleanup_test.go +++ b/pkg/services/annotations/annotationsimpl/cleanup_test.go @@ -132,57 +132,62 @@ func TestOldAnnotationsAreDeletedFirst(t *testing.T) { Created: time.Now().AddDate(-10, 0, -10).UnixNano() / int64(time.Millisecond), } - session := fakeSQL.NewSession(context.Background()) - defer session.Close() + err := fakeSQL.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + _, err := sess.Insert(a) + require.NoError(t, err, "cannot insert annotation") + _, err = sess.Insert(a) + require.NoError(t, err, "cannot insert annotation") - _, err := session.Insert(a) - require.NoError(t, err, "cannot insert annotation") - _, err = session.Insert(a) - require.NoError(t, err, "cannot insert annotation") + a.AlertId = 20 + _, err = sess.Insert(a) + require.NoError(t, err, "cannot insert annotation") - a.AlertId = 20 - _, err = session.Insert(a) - require.NoError(t, err, "cannot insert annotation") + // run the clean up task to keep one annotation. + cfg := setting.NewCfg() + cfg.AnnotationCleanupJobBatchSize = 1 + cleaner := &xormRepositoryImpl{cfg: cfg, log: log.New("test-logger"), db: fakeSQL} + _, err = cleaner.CleanAnnotations(context.Background(), setting.AnnotationCleanupSettings{MaxCount: 1}, alertAnnotationType) + require.NoError(t, err) - // run the clean up task to keep one annotation. - cfg := setting.NewCfg() - cfg.AnnotationCleanupJobBatchSize = 1 - cleaner := &xormRepositoryImpl{cfg: cfg, log: log.New("test-logger"), db: fakeSQL} - _, err = cleaner.CleanAnnotations(context.Background(), setting.AnnotationCleanupSettings{MaxCount: 1}, alertAnnotationType) + // assert that the last annotations were kept + countNew, err := sess.Where("alert_id = 20").Count(&annotations.Item{}) + require.NoError(t, err) + require.Equal(t, int64(1), countNew, "the last annotations should be kept") + + countOld, err := sess.Where("alert_id = 10").Count(&annotations.Item{}) + require.NoError(t, err) + require.Equal(t, int64(0), countOld, "the two first annotations should have been deleted") + + return nil + }) require.NoError(t, err) - - // assert that the last annotations were kept - countNew, err := session.Where("alert_id = 20").Count(&annotations.Item{}) - require.NoError(t, err) - require.Equal(t, int64(1), countNew, "the last annotations should be kept") - - countOld, err := session.Where("alert_id = 10").Count(&annotations.Item{}) - require.NoError(t, err) - require.Equal(t, int64(0), countOld, "the two first annotations should have been deleted") } func assertAnnotationCount(t *testing.T, fakeSQL *sqlstore.SQLStore, sql string, expectedCount int64) { t.Helper() - session := fakeSQL.NewSession(context.Background()) - defer session.Close() - count, err := session.Where(sql).Count(&annotations.Item{}) + err := fakeSQL.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + count, err := sess.Where(sql).Count(&annotations.Item{}) + require.NoError(t, err) + require.Equal(t, expectedCount, count) + return nil + }) require.NoError(t, err) - require.Equal(t, expectedCount, count) } func assertAnnotationTagCount(t *testing.T, fakeSQL *sqlstore.SQLStore, expectedCount int64) { t.Helper() - session := fakeSQL.NewSession(context.Background()) - defer session.Close() - - count, err := session.SQL("select count(*) from annotation_tag").Count() + err := fakeSQL.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + count, err := sess.SQL("select count(*) from annotation_tag").Count() + require.NoError(t, err) + require.Equal(t, expectedCount, count) + return nil + }) require.NoError(t, err) - require.Equal(t, expectedCount, count) } -func createTestAnnotations(t *testing.T, sqlstore *sqlstore.SQLStore, expectedCount int, oldAnnotations int) { +func createTestAnnotations(t *testing.T, store *sqlstore.SQLStore, expectedCount int, oldAnnotations int) { t.Helper() cutoffDate := time.Now() @@ -216,16 +221,19 @@ func createTestAnnotations(t *testing.T, sqlstore *sqlstore.SQLStore, expectedCo a.Created = cutoffDate.AddDate(-10, 0, -10).UnixNano() / int64(time.Millisecond) } - _, err := sqlstore.NewSession(context.Background()).Insert(a) - require.NoError(t, err, "should be able to save annotation", err) + err := store.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + _, err := sess.Insert(a) + require.NoError(t, err, "should be able to save annotation", err) - // mimick the SQL annotation Save logic by writing records to the annotation_tag table - // we need to ensure they get deleted when we clean up annotations - sess := sqlstore.NewSession(context.Background()) - for tagID := range []int{1, 2} { - _, err = sess.Exec("INSERT INTO annotation_tag (annotation_id, tag_id) VALUES(?,?)", a.Id, tagID) - require.NoError(t, err, "should be able to save annotation tag ID", err) - } + // mimick the SQL annotation Save logic by writing records to the annotation_tag table + // we need to ensure they get deleted when we clean up annotations + for tagID := range []int{1, 2} { + _, err = sess.Exec("INSERT INTO annotation_tag (annotation_id, tag_id) VALUES(?,?)", a.Id, tagID) + require.NoError(t, err, "should be able to save annotation tag ID", err) + } + return err + }) + require.NoError(t, err) } } diff --git a/pkg/services/auth/auth_token_test.go b/pkg/services/auth/auth_token_test.go index 938aca66bae..2f9a46bc142 100644 --- a/pkg/services/auth/auth_token_test.go +++ b/pkg/services/auth/auth_token_test.go @@ -566,40 +566,54 @@ type testContext struct { } func (c *testContext) getAuthTokenByID(id int64) (*userAuthToken, error) { - sess := c.sqlstore.NewSession(context.Background()) - var t userAuthToken - found, err := sess.ID(id).Get(&t) - if err != nil || !found { - return nil, err - } + var res *userAuthToken + err := c.sqlstore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + var t userAuthToken + found, err := sess.ID(id).Get(&t) + if err != nil || !found { + return err + } - return &t, nil + res = &t + return nil + }) + + return res, err } func (c *testContext) markAuthTokenAsSeen(id int64) (bool, error) { - sess := c.sqlstore.NewSession(context.Background()) - res, err := sess.Exec("UPDATE user_auth_token SET auth_token_seen = ? WHERE id = ?", c.sqlstore.Dialect.BooleanStr(true), id) - if err != nil { - return false, err - } + hasRowsAffected := false + err := c.sqlstore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + res, err := sess.Exec("UPDATE user_auth_token SET auth_token_seen = ? WHERE id = ?", c.sqlstore.Dialect.BooleanStr(true), id) + if err != nil { + return err + } - rowsAffected, err := res.RowsAffected() - if err != nil { - return false, err - } - return rowsAffected == 1, nil + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + hasRowsAffected = rowsAffected == 1 + return nil + }) + return hasRowsAffected, err } func (c *testContext) updateRotatedAt(id, rotatedAt int64) (bool, error) { - sess := c.sqlstore.NewSession(context.Background()) - res, err := sess.Exec("UPDATE user_auth_token SET rotated_at = ? WHERE id = ?", rotatedAt, id) - if err != nil { - return false, err - } + hasRowsAffected := false + err := c.sqlstore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + res, err := sess.Exec("UPDATE user_auth_token SET rotated_at = ? WHERE id = ?", rotatedAt, id) + if err != nil { + return err + } - rowsAffected, err := res.RowsAffected() - if err != nil { - return false, err - } - return rowsAffected == 1, nil + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + hasRowsAffected = rowsAffected == 1 + return nil + }) + return hasRowsAffected, err } diff --git a/pkg/services/auth/token_cleanup_test.go b/pkg/services/auth/token_cleanup_test.go index f611701ff18..dad055ef743 100644 --- a/pkg/services/auth/token_cleanup_test.go +++ b/pkg/services/auth/token_cleanup_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/stretchr/testify/require" ) @@ -21,8 +22,12 @@ func TestUserAuthTokenCleanup(t *testing.T) { insertToken := func(ctx *testContext, token string, prev string, createdAt, rotatedAt int64) { ut := userAuthToken{AuthToken: token, PrevAuthToken: prev, CreatedAt: createdAt, RotatedAt: rotatedAt, UserAgent: "", ClientIp: ""} - _, err := ctx.sqlstore.NewSession(context.Background()).Insert(&ut) - require.Nil(t, err) + err := ctx.sqlstore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error { + _, err := sess.Insert(&ut) + require.Nil(t, err) + return nil + }) + require.NoError(t, err) } now := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC) diff --git a/pkg/services/ngalert/store/image_test.go b/pkg/services/ngalert/store/image_test.go index 2680ba67950..58510f9d4ab 100644 --- a/pkg/services/ngalert/store/image_test.go +++ b/pkg/services/ngalert/store/image_test.go @@ -11,6 +11,7 @@ import ( "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/store" "github.com/grafana/grafana/pkg/services/ngalert/tests" + "github.com/grafana/grafana/pkg/services/sqlstore" ) func TestIntegrationSaveAndGetImage(t *testing.T) { @@ -168,30 +169,32 @@ func TestIntegrationDeleteExpiredImages(t *testing.T) { image2 := models.Image{URL: "https://example.com/example.png"} require.NoError(t, dbstore.SaveImage(ctx, &image2)) - s := dbstore.SQLStore.NewSession(ctx) - t.Cleanup(s.Close) + err := dbstore.SQLStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + // should return both images + var result1, result2 models.Image + ok, err := sess.Where("token = ?", image1.Token).Get(&result1) + require.NoError(t, err) + assert.True(t, ok) + ok, err = sess.Where("token = ?", image2.Token).Get(&result2) + require.NoError(t, err) + assert.True(t, ok) - // should return both images - var result1, result2 models.Image - ok, err := s.Where("token = ?", image1.Token).Get(&result1) - require.NoError(t, err) - assert.True(t, ok) - ok, err = s.Where("token = ?", image2.Token).Get(&result2) - require.NoError(t, err) - assert.True(t, ok) + // should delete expired image + image1.ExpiresAt = time.Now().Add(-time.Second) + require.NoError(t, dbstore.SaveImage(ctx, &image1)) + n, err := dbstore.DeleteExpiredImages(ctx) + require.NoError(t, err) + assert.Equal(t, int64(1), n) - // should delete expired image - image1.ExpiresAt = time.Now().Add(-time.Second) - require.NoError(t, dbstore.SaveImage(ctx, &image1)) - n, err := dbstore.DeleteExpiredImages(ctx) - require.NoError(t, err) - assert.Equal(t, int64(1), n) + // should return just the second image + ok, err = sess.Where("token = ?", image1.Token).Get(&result1) + require.NoError(t, err) + assert.False(t, ok) + ok, err = sess.Where("token = ?", image2.Token).Get(&result2) + require.NoError(t, err) + assert.True(t, ok) - // should return just the second image - ok, err = s.Where("token = ?", image1.Token).Get(&result1) + return nil + }) require.NoError(t, err) - assert.False(t, ok) - ok, err = s.Where("token = ?", image2.Token).Get(&result2) - require.NoError(t, err) - assert.True(t, ok) } diff --git a/pkg/services/secrets/database/database.go b/pkg/services/secrets/database/database.go index 7a88a9f4862..348122bd049 100644 --- a/pkg/services/secrets/database/database.go +++ b/pkg/services/secrets/database/database.go @@ -126,7 +126,9 @@ func (ss *SecretsStoreImpl) ReEncryptDataKeys( currProvider secrets.ProviderID, ) error { keys := make([]*secrets.DataKey, 0) - if err := ss.sqlStore.NewSession(ctx).Table(dataKeysTable).Find(&keys); err != nil { + if err := ss.sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(dataKeysTable).Find(&keys) + }); err != nil { return err } diff --git a/pkg/services/secrets/migrator/migrator.go b/pkg/services/secrets/migrator/migrator.go index 07f29e84d8c..a87838b92b9 100644 --- a/pkg/services/secrets/migrator/migrator.go +++ b/pkg/services/secrets/migrator/migrator.go @@ -104,8 +104,10 @@ func (m *SecretsMigrator) RollBackSecrets(ctx context.Context) (bool, error) { return false, nil } - _, sqlErr := m.sqlStore.NewSession(ctx).Exec("DELETE FROM data_keys") - if sqlErr != nil { + if sqlErr := m.sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + _, err := sess.Exec("DELETE FROM data_keys") + return err + }); sqlErr != nil { logger.Warn("Error while cleaning up data keys table...", "error", sqlErr) return false, nil } diff --git a/pkg/services/secrets/migrator/reencrypt.go b/pkg/services/secrets/migrator/reencrypt.go index 80a8adb6cde..10c2ecca596 100644 --- a/pkg/services/secrets/migrator/reencrypt.go +++ b/pkg/services/secrets/migrator/reencrypt.go @@ -18,7 +18,9 @@ func (s simpleSecret) reencrypt(ctx context.Context, secretsSrv *manager.Secrets Secret []byte } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to re-encrypt", "table", s.tableName) return false } @@ -72,7 +74,9 @@ func (s b64Secret) reencrypt(ctx context.Context, secretsSrv *manager.SecretsSer Secret string } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to re-encrypt", "table", s.tableName) return false } @@ -140,7 +144,9 @@ func (s jsonSecret) reencrypt(ctx context.Context, secretsSrv *manager.SecretsSe SecureJsonData map[string][]byte } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Cols("id", "secure_json_data").Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Cols("id", "secure_json_data").Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to re-encrypt", "table", s.tableName) return false } @@ -199,7 +205,9 @@ func (s alertingSecret) reencrypt(ctx context.Context, secretsSrv *manager.Secre } selectSQL := "SELECT id, alertmanager_configuration FROM alert_configuration" - if err := sqlStore.NewSession(ctx).SQL(selectSQL).Find(&results); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.SQL(selectSQL).Find(&results) + }); err != nil { logger.Warn("Could not find any alert_configuration secret to re-encrypt") return false } diff --git a/pkg/services/secrets/migrator/rollback.go b/pkg/services/secrets/migrator/rollback.go index 7668ae8e801..2fa396b5d83 100644 --- a/pkg/services/secrets/migrator/rollback.go +++ b/pkg/services/secrets/migrator/rollback.go @@ -24,7 +24,9 @@ func (s simpleSecret) rollback( Secret []byte } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to roll back", "table", s.tableName) return true } @@ -82,7 +84,9 @@ func (s b64Secret) rollback( Secret string } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Select(fmt.Sprintf("id, %s as secret", s.columnName)).Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to roll back", "table", s.tableName) return true } @@ -154,7 +158,9 @@ func (s jsonSecret) rollback( SecureJsonData map[string][]byte } - if err := sqlStore.NewSession(ctx).Table(s.tableName).Cols("id", "secure_json_data").Find(&rows); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.Table(s.tableName).Cols("id", "secure_json_data").Find(&rows) + }); err != nil { logger.Warn("Could not find any secret to roll back", "table", s.tableName) return true } @@ -217,7 +223,9 @@ func (s alertingSecret) rollback( } selectSQL := "SELECT id, alertmanager_configuration FROM alert_configuration" - if err := sqlStore.NewSession(ctx).SQL(selectSQL).Find(&results); err != nil { + if err := sqlStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + return sess.SQL(selectSQL).Find(&results) + }); err != nil { logger.Warn("Could not find any alert_configuration secret to roll back") return true } diff --git a/pkg/services/sqlstore/db/db.go b/pkg/services/sqlstore/db/db.go index 8af1321bd00..85100f924a4 100644 --- a/pkg/services/sqlstore/db/db.go +++ b/pkg/services/sqlstore/db/db.go @@ -12,7 +12,7 @@ import ( type DB interface { WithTransactionalDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error WithDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error - NewSession(ctx context.Context) *sqlstore.DBSession + WithNewDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error GetDialect() migrator.Dialect GetDBType() core.DbType GetSqlxSession() *session.SessionDB diff --git a/pkg/services/sqlstore/db/dbtest/dbtest.go b/pkg/services/sqlstore/db/dbtest/dbtest.go index c8013cb4f49..d069a096d30 100644 --- a/pkg/services/sqlstore/db/dbtest/dbtest.go +++ b/pkg/services/sqlstore/db/dbtest/dbtest.go @@ -21,3 +21,7 @@ func (f *FakeDB) WithTransactionalDbSession(ctx context.Context, callback sqlsto func (f *FakeDB) WithDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error { return f.ExpectedError } + +func (f *FakeDB) WithNewDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error { + return f.ExpectedError +} diff --git a/pkg/services/sqlstore/mockstore/mockstore.go b/pkg/services/sqlstore/mockstore/mockstore.go index 48941e14e13..c493bfc8046 100644 --- a/pkg/services/sqlstore/mockstore/mockstore.go +++ b/pkg/services/sqlstore/mockstore/mockstore.go @@ -222,6 +222,10 @@ func (m *SQLStoreMock) WithDbSession(ctx context.Context, callback sqlstore.DBTr return m.ExpectedError } +func (m *SQLStoreMock) WithNewDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error { + return m.ExpectedError +} + func (m *SQLStoreMock) GetOrgQuotaByTarget(ctx context.Context, query *models.GetOrgQuotaByTargetQuery) error { return m.ExpectedError } diff --git a/pkg/services/sqlstore/session.go b/pkg/services/sqlstore/session.go index 9289acb66ac..b8a3b9ccc56 100644 --- a/pkg/services/sqlstore/session.go +++ b/pkg/services/sqlstore/session.go @@ -27,13 +27,6 @@ func (sess *DBSession) PublishAfterCommit(msg interface{}) { sess.events = append(sess.events, msg) } -// NewSession returns a new DBSession -func (ss *SQLStore) NewSession(ctx context.Context) *DBSession { - sess := &DBSession{Session: ss.engine.NewSession()} - sess.Session = sess.Session.Context(ctx) - return sess -} - func startSessionOrUseExisting(ctx context.Context, engine *xorm.Engine, beginTran bool) (*DBSession, bool, error) { value := ctx.Value(ContextSessionKey{}) var sess *DBSession @@ -55,14 +48,24 @@ func startSessionOrUseExisting(ctx context.Context, engine *xorm.Engine, beginTr } newSess.Session = newSess.Session.Context(ctx) + return newSess, true, nil } -// WithDbSession calls the callback with a session. +// WithDbSession calls the callback with the session in the context (if exists). +// Otherwise it creates a new one that is closed upon completion. +// A session is stored in the context if sqlstore.InTransaction() has been been previously called with the same context (and it's not committed/rolledback yet). func (ss *SQLStore) WithDbSession(ctx context.Context, callback DBTransactionFunc) error { return withDbSession(ctx, ss.engine, callback) } +// WithNewDbSession calls the callback with a new session that is closed upon completion. +func (ss *SQLStore) WithNewDbSession(ctx context.Context, callback DBTransactionFunc) error { + sess := &DBSession{Session: ss.engine.NewSession(), transactionOpen: false} + defer sess.Close() + return callback(sess) +} + func withDbSession(ctx context.Context, engine *xorm.Engine, callback DBTransactionFunc) error { sess, isNew, err := startSessionOrUseExisting(ctx, engine, false) if err != nil { diff --git a/pkg/services/sqlstore/store.go b/pkg/services/sqlstore/store.go index 23412cf1d25..fab61c5fbd9 100644 --- a/pkg/services/sqlstore/store.go +++ b/pkg/services/sqlstore/store.go @@ -30,8 +30,8 @@ type Store interface { GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error UpdateUserPermissions(userID int64, isAdmin bool) error SetUserHelpFlag(ctx context.Context, cmd *models.SetUserHelpFlagCommand) error - NewSession(ctx context.Context) *DBSession WithDbSession(ctx context.Context, callback DBTransactionFunc) error + WithNewDbSession(ctx context.Context, callback DBTransactionFunc) error GetOrgQuotaByTarget(ctx context.Context, query *models.GetOrgQuotaByTargetQuery) error GetOrgQuotas(ctx context.Context, query *models.GetOrgQuotasQuery) error UpdateOrgQuota(ctx context.Context, cmd *models.UpdateOrgQuotaCmd) error diff --git a/pkg/services/sqlstore/transactions.go b/pkg/services/sqlstore/transactions.go index baf3efd2c94..15a3f2e0ccd 100644 --- a/pkg/services/sqlstore/transactions.go +++ b/pkg/services/sqlstore/transactions.go @@ -20,6 +20,8 @@ func (ss *SQLStore) WithTransactionalDbSession(ctx context.Context, callback DBT return inTransactionWithRetryCtx(ctx, ss.engine, ss.bus, callback, 0) } +// InTransaction starts a transaction and calls the fn +// It stores the session in the context func (ss *SQLStore) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error { return ss.inTransactionWithRetry(ctx, fn, 0) }