Chore: Add context to datasource service (#42294)

* Add context to datasource service

* Adjust wire for ShouldBeReported method

* Replace inTransactionCtx
This commit is contained in:
idafurjes 2021-11-26 18:10:36 +01:00 committed by GitHub
parent 6aa05c5d05
commit 725dbf8d95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 69 additions and 64 deletions

View File

@ -24,7 +24,7 @@ var datasourcesLogger = log.New("datasources")
func (hs *HTTPServer) GetDataSources(c *models.ReqContext) response.Response { func (hs *HTTPServer) GetDataSources(c *models.ReqContext) response.Response {
query := models.GetDataSourcesQuery{OrgId: c.OrgId, DataSourceLimit: hs.Cfg.DataSourceLimit} query := models.GetDataSourcesQuery{OrgId: c.OrgId, DataSourceLimit: hs.Cfg.DataSourceLimit}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
return response.Error(500, "Failed to query datasources", err) return response.Error(500, "Failed to query datasources", err)
} }

View File

@ -25,7 +25,7 @@ func (hs *HTTPServer) getFSDataSources(c *models.ReqContext, enabledPlugins Enab
if c.OrgId != 0 { if c.OrgId != 0 {
query := models.GetDataSourcesQuery{OrgId: c.OrgId, DataSourceLimit: hs.Cfg.DataSourceLimit} query := models.GetDataSourcesQuery{OrgId: c.OrgId, DataSourceLimit: hs.Cfg.DataSourceLimit}
err := bus.Dispatch(&query) err := bus.DispatchCtx(c.Req.Context(), &query)
if err != nil { if err != nil {
return nil, err return nil, err
@ -36,7 +36,7 @@ func (hs *HTTPServer) getFSDataSources(c *models.ReqContext, enabledPlugins Enab
Datasources: query.Result, Datasources: query.Result,
} }
if err := bus.Dispatch(&dsFilterQuery); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &dsFilterQuery); err != nil {
if !errors.Is(err, bus.ErrHandlerNotFound) { if !errors.Is(err, bus.ErrHandlerNotFound) {
return nil, err return nil, err
} }

View File

@ -51,7 +51,7 @@ func (noOpUsageStats) RegisterMetricsFunc(_ usagestats.MetricsFunc) {}
func (noOpUsageStats) RegisterSendReportCallback(_ usagestats.SendReportCallbackFunc) {} func (noOpUsageStats) RegisterSendReportCallback(_ usagestats.SendReportCallbackFunc) {}
func (noOpUsageStats) ShouldBeReported(string) bool { return false } func (noOpUsageStats) ShouldBeReported(context.Context, string) bool { return false }
type noOpRouteRegister struct{} type noOpRouteRegister struct{}

View File

@ -29,7 +29,7 @@ func (usm *UsageStatsMock) GetUsageReport(ctx context.Context) (Report, error) {
return Report{Metrics: all}, nil return Report{Metrics: all}, nil
} }
func (usm *UsageStatsMock) ShouldBeReported(_ string) bool { func (usm *UsageStatsMock) ShouldBeReported(_ context.Context, _ string) bool {
return true return true
} }

View File

@ -23,5 +23,5 @@ type Service interface {
GetUsageReport(context.Context) (Report, error) GetUsageReport(context.Context) (Report, error)
RegisterMetricsFunc(MetricsFunc) RegisterMetricsFunc(MetricsFunc)
RegisterSendReportCallback(SendReportCallbackFunc) RegisterSendReportCallback(SendReportCallbackFunc)
ShouldBeReported(string) bool ShouldBeReported(context.Context, string) bool
} }

View File

@ -122,7 +122,7 @@ func (uss *UsageStats) GetUsageReport(ctx context.Context) (usagestats.Report, e
// as sending that name could be sensitive information // as sending that name could be sensitive information
dsOtherCount := 0 dsOtherCount := 0
for _, dsStat := range dsStats.Result { for _, dsStat := range dsStats.Result {
if uss.ShouldBeReported(dsStat.Type) { if uss.ShouldBeReported(ctx, dsStat.Type) {
metrics["stats.ds."+dsStat.Type+".count"] = dsStat.Count metrics["stats.ds."+dsStat.Type+".count"] = dsStat.Count
} else { } else {
dsOtherCount += dsStat.Count dsOtherCount += dsStat.Count
@ -131,7 +131,7 @@ func (uss *UsageStats) GetUsageReport(ctx context.Context) (usagestats.Report, e
metrics["stats.ds.other.count"] = dsOtherCount metrics["stats.ds.other.count"] = dsOtherCount
esDataSourcesQuery := models.GetDataSourcesByTypeQuery{Type: models.DS_ES} esDataSourcesQuery := models.GetDataSourcesByTypeQuery{Type: models.DS_ES}
if err := uss.Bus.Dispatch(&esDataSourcesQuery); err != nil { if err := uss.Bus.DispatchCtx(ctx, &esDataSourcesQuery); err != nil {
uss.log.Error("Failed to get elasticsearch json data", "error", err) uss.log.Error("Failed to get elasticsearch json data", "error", err)
return report, err return report, err
} }
@ -170,7 +170,7 @@ func (uss *UsageStats) GetUsageReport(ctx context.Context) (usagestats.Report, e
access := strings.ToLower(dsAccessStat.Access) access := strings.ToLower(dsAccessStat.Access)
if uss.ShouldBeReported(dsAccessStat.Type) { if uss.ShouldBeReported(ctx, dsAccessStat.Type) {
metrics["stats.ds_access."+dsAccessStat.Type+"."+access+".count"] = dsAccessStat.Count metrics["stats.ds_access."+dsAccessStat.Type+"."+access+".count"] = dsAccessStat.Count
} else { } else {
old := dsAccessOtherCount[access] old := dsAccessOtherCount[access]
@ -329,8 +329,8 @@ func (uss *UsageStats) updateTotalStats(ctx context.Context) {
} }
} }
func (uss *UsageStats) ShouldBeReported(dsType string) bool { func (uss *UsageStats) ShouldBeReported(ctx context.Context, dsType string) bool {
ds, exists := uss.pluginStore.Plugin(context.TODO(), dsType) ds, exists := uss.pluginStore.Plugin(ctx, dsType)
if !exists { if !exists {
return false return false
} }

View File

@ -33,7 +33,7 @@ func (e *AlertEngine) QueryUsageStats(ctx context.Context) (*UsageStats, error)
return nil, err return nil, err
} }
dsUsage, err := e.mapRulesToUsageStats(cmd.Result) dsUsage, err := e.mapRulesToUsageStats(ctx, cmd.Result)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -43,7 +43,7 @@ func (e *AlertEngine) QueryUsageStats(ctx context.Context) (*UsageStats, error)
}, nil }, nil
} }
func (e *AlertEngine) mapRulesToUsageStats(rules []*models.Alert) (DatasourceAlertUsage, error) { func (e *AlertEngine) mapRulesToUsageStats(ctx context.Context, rules []*models.Alert) (DatasourceAlertUsage, error) {
// map of datasourceId type and frequency // map of datasourceId type and frequency
typeCount := map[int64]int{} typeCount := map[int64]int{}
for _, a := range rules { for _, a := range rules {
@ -63,7 +63,7 @@ func (e *AlertEngine) mapRulesToUsageStats(rules []*models.Alert) (DatasourceAle
result := map[string]int{} result := map[string]int{}
for k, v := range typeCount { for k, v := range typeCount {
query := &models.GetDataSourceQuery{Id: k} query := &models.GetDataSourceQuery{Id: k}
err := e.Bus.DispatchCtx(context.TODO(), query) err := e.Bus.DispatchCtx(ctx, query)
if err != nil { if err != nil {
return map[string]int{}, nil return map[string]int{}, nil
} }

View File

@ -257,7 +257,7 @@ func (e *AlertEngine) registerUsageMetrics() {
metrics := map[string]interface{}{} metrics := map[string]interface{}{}
for dsType, usageCount := range alertingUsageStats.DatasourceUsage { for dsType, usageCount := range alertingUsageStats.DatasourceUsage {
if e.usageStatsService.ShouldBeReported(dsType) { if e.usageStatsService.ShouldBeReported(ctx, dsType) {
metrics[fmt.Sprintf("stats.alerting.ds.%s.count", dsType)] = usageCount metrics[fmt.Sprintf("stats.alerting.ds.%s.count", dsType)] = usageCount
} else { } else {
alertingOtherCount += usageCount alertingOtherCount += usageCount

View File

@ -30,7 +30,7 @@ func NewDashAlertExtractor(dash *models.Dashboard, orgID int64, user *models.Sig
} }
} }
func (e *DashAlertExtractor) lookupQueryDataSource(panel *simplejson.Json, panelQuery *simplejson.Json) (*models.DataSource, error) { func (e *DashAlertExtractor) lookupQueryDataSource(ctx context.Context, panel *simplejson.Json, panelQuery *simplejson.Json) (*models.DataSource, error) {
dsName := "" dsName := ""
dsUid := "" dsUid := ""
@ -48,14 +48,14 @@ func (e *DashAlertExtractor) lookupQueryDataSource(panel *simplejson.Json, panel
if dsName == "" && dsUid == "" { if dsName == "" && dsUid == "" {
query := &models.GetDefaultDataSourceQuery{OrgId: e.OrgID} query := &models.GetDefaultDataSourceQuery{OrgId: e.OrgID}
if err := bus.DispatchCtx(context.TODO(), query); err != nil { if err := bus.DispatchCtx(ctx, query); err != nil {
return nil, err return nil, err
} }
return query.Result, nil return query.Result, nil
} }
query := &models.GetDataSourceQuery{Name: dsName, Uid: dsUid, OrgId: e.OrgID} query := &models.GetDataSourceQuery{Name: dsName, Uid: dsUid, OrgId: e.OrgID}
if err := bus.DispatchCtx(context.TODO(), query); err != nil { if err := bus.DispatchCtx(ctx, query); err != nil {
return nil, err return nil, err
} }
@ -174,7 +174,7 @@ func (e *DashAlertExtractor) getAlertFromPanels(ctx context.Context, jsonWithPan
return nil, ValidationError{Reason: reason} return nil, ValidationError{Reason: reason}
} }
datasource, err := e.lookupQueryDataSource(panel, panelQuery) datasource, err := e.lookupQueryDataSource(ctx, panel, panelQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,7 +184,7 @@ func (e *DashAlertExtractor) getAlertFromPanels(ctx context.Context, jsonWithPan
Datasources: []*models.DataSource{datasource}, Datasources: []*models.DataSource{datasource},
} }
if err := bus.Dispatch(&dsFilterQuery); err != nil { if err := bus.DispatchCtx(ctx, &dsFilterQuery); err != nil {
if !errors.Is(err, bus.ErrHandlerNotFound) { if !errors.Is(err, bus.ErrHandlerNotFound) {
return nil, err return nil, err
} }

View File

@ -62,13 +62,13 @@ func ProvideService(bus bus.Bus, store *sqlstore.SQLStore, secretsService secret
}, },
} }
s.Bus.AddHandler(s.GetDataSources) s.Bus.AddHandlerCtx(s.GetDataSources)
s.Bus.AddHandler(s.GetDataSourcesByType) s.Bus.AddHandlerCtx(s.GetDataSourcesByType)
s.Bus.AddHandlerCtx(s.GetDataSource) s.Bus.AddHandlerCtx(s.GetDataSource)
s.Bus.AddHandlerCtx(s.AddDataSource) s.Bus.AddHandlerCtx(s.AddDataSource)
s.Bus.AddHandlerCtx(s.DeleteDataSource) s.Bus.AddHandlerCtx(s.DeleteDataSource)
s.Bus.AddHandlerCtx(s.UpdateDataSource) s.Bus.AddHandlerCtx(s.UpdateDataSource)
s.Bus.AddHandler(s.GetDefaultDataSource) s.Bus.AddHandlerCtx(s.GetDefaultDataSource)
return s return s
} }
@ -77,12 +77,12 @@ func (s *Service) GetDataSource(ctx context.Context, query *models.GetDataSource
return s.SQLStore.GetDataSource(ctx, query) return s.SQLStore.GetDataSource(ctx, query)
} }
func (s *Service) GetDataSources(query *models.GetDataSourcesQuery) error { func (s *Service) GetDataSources(ctx context.Context, query *models.GetDataSourcesQuery) error {
return s.SQLStore.GetDataSources(query) return s.SQLStore.GetDataSources(ctx, query)
} }
func (s *Service) GetDataSourcesByType(query *models.GetDataSourcesByTypeQuery) error { func (s *Service) GetDataSourcesByType(ctx context.Context, query *models.GetDataSourcesByTypeQuery) error {
return s.SQLStore.GetDataSourcesByType(query) return s.SQLStore.GetDataSourcesByType(ctx, query)
} }
func (s *Service) AddDataSource(ctx context.Context, cmd *models.AddDataSourceCommand) error { func (s *Service) AddDataSource(ctx context.Context, cmd *models.AddDataSourceCommand) error {
@ -109,8 +109,8 @@ func (s *Service) UpdateDataSource(ctx context.Context, cmd *models.UpdateDataSo
return s.SQLStore.UpdateDataSource(ctx, cmd) return s.SQLStore.UpdateDataSource(ctx, cmd)
} }
func (s *Service) GetDefaultDataSource(query *models.GetDefaultDataSourceQuery) error { func (s *Service) GetDefaultDataSource(ctx context.Context, query *models.GetDefaultDataSourceQuery) error {
return s.SQLStore.GetDefaultDataSource(query) return s.SQLStore.GetDefaultDataSource(ctx, query)
} }
func (s *Service) GetHTTPClient(ds *models.DataSource, provider httpclient.Provider) (*http.Client, error) { func (s *Service) GetHTTPClient(ds *models.DataSource, provider httpclient.Provider) (*http.Client, error) {

View File

@ -40,33 +40,37 @@ func (ss *SQLStore) GetDataSource(ctx context.Context, query *models.GetDataSour
}) })
} }
func (ss *SQLStore) GetDataSources(query *models.GetDataSourcesQuery) error { func (ss *SQLStore) GetDataSources(ctx context.Context, query *models.GetDataSourcesQuery) error {
var sess *xorm.Session var sess *xorm.Session
return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
if query.DataSourceLimit <= 0 { if query.DataSourceLimit <= 0 {
sess = x.Where("org_id=?", query.OrgId).Asc("name") sess = dbSess.Where("org_id=?", query.OrgId).Asc("name")
} else { } else {
sess = x.Limit(query.DataSourceLimit, 0).Where("org_id=?", query.OrgId).Asc("name") sess = dbSess.Limit(query.DataSourceLimit, 0).Where("org_id=?", query.OrgId).Asc("name")
} }
query.Result = make([]*models.DataSource, 0) query.Result = make([]*models.DataSource, 0)
return sess.Find(&query.Result) return sess.Find(&query.Result)
})
} }
// GetDataSourcesByType returns all datasources for a given type or an error if the specified type is an empty string // GetDataSourcesByType returns all datasources for a given type or an error if the specified type is an empty string
func (ss *SQLStore) GetDataSourcesByType(query *models.GetDataSourcesByTypeQuery) error { func (ss *SQLStore) GetDataSourcesByType(ctx context.Context, query *models.GetDataSourcesByTypeQuery) error {
if query.Type == "" { if query.Type == "" {
return fmt.Errorf("datasource type cannot be empty") return fmt.Errorf("datasource type cannot be empty")
} }
query.Result = make([]*models.DataSource, 0) query.Result = make([]*models.DataSource, 0)
return x.Where("type=?", query.Type).Asc("id").Find(&query.Result) return ss.WithDbSession(ctx, func(sess *DBSession) error {
return sess.Where("type=?", query.Type).Asc("id").Find(&query.Result)
})
} }
// GetDefaultDataSource is used to get the default datasource of organization // GetDefaultDataSource is used to get the default datasource of organization
func (ss *SQLStore) GetDefaultDataSource(query *models.GetDefaultDataSourceQuery) error { func (ss *SQLStore) GetDefaultDataSource(ctx context.Context, query *models.GetDefaultDataSourceQuery) error {
datasource := models.DataSource{} datasource := models.DataSource{}
return ss.WithDbSession(ctx, func(sess *DBSession) error {
exists, err := x.Where("org_id=? AND is_default=?", query.OrgId, true).Get(&datasource) exists, err := sess.Where("org_id=? AND is_default=?", query.OrgId, true).Get(&datasource)
if !exists { if !exists {
return models.ErrDataSourceNotFound return models.ErrDataSourceNotFound
@ -74,6 +78,7 @@ func (ss *SQLStore) GetDefaultDataSource(query *models.GetDefaultDataSourceQuery
query.Result = &datasource query.Result = &datasource
return err return err
})
} }
// DeleteDataSource removes a datasource by org_id as well as either uid (preferred), id, or name // DeleteDataSource removes a datasource by org_id as well as either uid (preferred), id, or name
@ -194,7 +199,7 @@ func updateIsDefaultFlag(ds *models.DataSource, sess *DBSession) error {
} }
func (ss *SQLStore) UpdateDataSource(ctx context.Context, cmd *models.UpdateDataSourceCommand) error { func (ss *SQLStore) UpdateDataSource(ctx context.Context, cmd *models.UpdateDataSourceCommand) error {
return inTransactionCtx(ctx, func(sess *DBSession) error { return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if cmd.JsonData == nil { if cmd.JsonData == nil {
cmd.JsonData = simplejson.New() cmd.JsonData = simplejson.New()
} }

View File

@ -40,7 +40,7 @@ func TestDataAccess(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -63,7 +63,7 @@ func TestDataAccess(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -97,7 +97,7 @@ func TestDataAccess(t *testing.T) {
sqlStore := InitTestDB(t) sqlStore := InitTestDB(t)
var created *events.DataSourceCreated var created *events.DataSourceCreated
bus.AddEventListener(func(e *events.DataSourceCreated) error { bus.AddEventListenerCtx(func(ctx context.Context, e *events.DataSourceCreated) error {
created = e created = e
return nil return nil
}) })
@ -110,7 +110,7 @@ func TestDataAccess(t *testing.T) {
}, time.Second, time.Millisecond) }, time.Second, time.Millisecond)
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -216,7 +216,7 @@ func TestDataAccess(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(query.Result)) require.Equal(t, 0, len(query.Result))
@ -229,7 +229,7 @@ func TestDataAccess(t *testing.T) {
err := sqlStore.DeleteDataSource(context.Background(), &models.DeleteDataSourceCommand{ID: ds.Id, OrgID: 123123}) err := sqlStore.DeleteDataSource(context.Background(), &models.DeleteDataSourceCommand{ID: ds.Id, OrgID: 123123})
require.NoError(t, err) require.NoError(t, err)
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -241,7 +241,7 @@ func TestDataAccess(t *testing.T) {
ds := initDatasource(sqlStore) ds := initDatasource(sqlStore)
var deleted *events.DataSourceDeleted var deleted *events.DataSourceDeleted
bus.AddEventListener(func(e *events.DataSourceDeleted) error { bus.AddEventListenerCtx(func(ctx context.Context, e *events.DataSourceDeleted) error {
deleted = e deleted = e
return nil return nil
}) })
@ -267,7 +267,7 @@ func TestDataAccess(t *testing.T) {
err := sqlStore.DeleteDataSource(context.Background(), &models.DeleteDataSourceCommand{Name: ds.Name, OrgID: ds.OrgId}) err := sqlStore.DeleteDataSource(context.Background(), &models.DeleteDataSourceCommand{Name: ds.Name, OrgID: ds.OrgId})
require.NoError(t, err) require.NoError(t, err)
err = sqlStore.GetDataSources(&query) err = sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(query.Result)) require.Equal(t, 0, len(query.Result))
@ -291,7 +291,7 @@ func TestDataAccess(t *testing.T) {
} }
query := models.GetDataSourcesQuery{OrgId: 10, DataSourceLimit: datasourceLimit} query := models.GetDataSourcesQuery{OrgId: 10, DataSourceLimit: datasourceLimit}
err := sqlStore.GetDataSources(&query) err := sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, datasourceLimit, len(query.Result)) require.Equal(t, datasourceLimit, len(query.Result))
@ -314,7 +314,7 @@ func TestDataAccess(t *testing.T) {
} }
query := models.GetDataSourcesQuery{OrgId: 10} query := models.GetDataSourcesQuery{OrgId: 10}
err := sqlStore.GetDataSources(&query) err := sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, numberOfDatasource, len(query.Result)) require.Equal(t, numberOfDatasource, len(query.Result))
@ -337,7 +337,7 @@ func TestDataAccess(t *testing.T) {
} }
query := models.GetDataSourcesQuery{OrgId: 10, DataSourceLimit: -1} query := models.GetDataSourcesQuery{OrgId: 10, DataSourceLimit: -1}
err := sqlStore.GetDataSources(&query) err := sqlStore.GetDataSources(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, numberOfDatasource, len(query.Result)) require.Equal(t, numberOfDatasource, len(query.Result))
@ -372,7 +372,7 @@ func TestDataAccess(t *testing.T) {
query := models.GetDataSourcesByTypeQuery{Type: models.DS_ES} query := models.GetDataSourcesByTypeQuery{Type: models.DS_ES}
err = sqlStore.GetDataSourcesByType(&query) err = sqlStore.GetDataSourcesByType(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -383,7 +383,7 @@ func TestDataAccess(t *testing.T) {
query := models.GetDataSourcesByTypeQuery{} query := models.GetDataSourcesByTypeQuery{}
err := sqlStore.GetDataSourcesByType(&query) err := sqlStore.GetDataSourcesByType(context.Background(), &query)
require.Error(t, err) require.Error(t, err)
}) })
@ -408,7 +408,7 @@ func TestGetDefaultDataSource(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
query := models.GetDefaultDataSourceQuery{OrgId: 10} query := models.GetDefaultDataSourceQuery{OrgId: 10}
err = sqlStore.GetDefaultDataSource(&query) err = sqlStore.GetDefaultDataSource(context.Background(), &query)
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, models.ErrDataSourceNotFound)) assert.True(t, errors.Is(err, models.ErrDataSourceNotFound))
}) })
@ -429,7 +429,7 @@ func TestGetDefaultDataSource(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
query := models.GetDefaultDataSourceQuery{OrgId: 10} query := models.GetDefaultDataSourceQuery{OrgId: 10}
err = sqlStore.GetDefaultDataSource(&query) err = sqlStore.GetDefaultDataSource(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "default datasource", query.Result.Name) assert.Equal(t, "default datasource", query.Result.Name)
}) })
@ -437,7 +437,7 @@ func TestGetDefaultDataSource(t *testing.T) {
t.Run("should not return default datasource of other organisation", func(t *testing.T) { t.Run("should not return default datasource of other organisation", func(t *testing.T) {
sqlStore := InitTestDB(t) sqlStore := InitTestDB(t)
query := models.GetDefaultDataSourceQuery{OrgId: 1} query := models.GetDefaultDataSourceQuery{OrgId: 1}
err := sqlStore.GetDefaultDataSource(&query) err := sqlStore.GetDefaultDataSource(context.Background(), &query)
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, models.ErrDataSourceNotFound)) assert.True(t, errors.Is(err, models.ErrDataSourceNotFound))
}) })