Resource server improvements and fixes (#90715)

* cleanup dependencies and improve list method
* Improve Resource Server API, remove unnecessary dependencies
* Reduce the API footprint of ResourceDBInterface and its implementation
* Improve LifecycleHooks to use context
* Improve testing
* reduce API size and improve code
* sqltemplate: add DialectForDriver func and improve naming
* improve lifecycle API
* many small fixes after adding more tests
This commit is contained in:
Diego Augusto Molina 2024-07-22 14:08:30 -03:00 committed by GitHub
parent 5f367f05dc
commit 399d77a0fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 2193 additions and 945 deletions

View File

@ -392,7 +392,7 @@ func TestQueries(t *testing.T) {
expectedQuery := sqltemplate.FormatSQL(rawQuery)
for _, d := range ds {
t.Run(d.Name(), func(t *testing.T) {
t.Run(d.DialectName(), func(t *testing.T) {
// not parallel for the same reason
tc.Data.SetDialect(d)

View File

@ -44,7 +44,7 @@ type entityBridge struct {
}
// Init implements ResourceServer.
func (b *entityBridge) Init() error {
func (b *entityBridge) Init(context.Context) error {
if b.server != nil {
return b.server.Init()
}
@ -52,10 +52,11 @@ func (b *entityBridge) Init() error {
}
// Stop implements ResourceServer.
func (b *entityBridge) Stop() {
func (b *entityBridge) Stop(context.Context) error {
if b.server != nil {
b.server.Stop()
}
return nil
}
// Convert resource key to the entity key

View File

@ -18,10 +18,10 @@ type WriteAccessHooks struct {
type LifecycleHooks interface {
// Called once at initialization
Init() error
Init(context.Context) error
// Stop function -- after calling this, any additional storage functions may error
Stop()
Stop(context.Context) error
}
func (a *WriteAccessHooks) CanWriteFolder(ctx context.Context, user identity.Requester, uid string) error {

View File

@ -15,13 +15,13 @@ var (
type noopService struct{}
// Init implements ResourceServer.
func (n *noopService) Init() error {
func (n *noopService) Init(context.Context) error {
return nil
}
// Stop implements ResourceServer.
func (n *noopService) Stop() {
// nothing
func (n *noopService) Stop(context.Context) error {
return nil
}
// IsHealthy implements ResourceServer.

View File

@ -150,11 +150,11 @@ type server struct {
}
// Init implements ResourceServer.
func (s *server) Init() error {
func (s *server) Init(ctx context.Context) error {
s.once.Do(func() {
// Call lifecycle hooks
if s.lifecycle != nil {
err := s.lifecycle.Init()
err := s.lifecycle.Init(ctx)
if err != nil {
s.initErr = fmt.Errorf("initialize Resource Server: %w", err)
}
@ -172,18 +172,28 @@ func (s *server) Init() error {
return s.initErr
}
func (s *server) Stop() {
func (s *server) Stop(ctx context.Context) error {
s.initErr = fmt.Errorf("service is stopping")
var stopFailed bool
if s.lifecycle != nil {
s.lifecycle.Stop()
err := s.lifecycle.Stop(ctx)
if err != nil {
stopFailed = true
s.initErr = fmt.Errorf("service stopeed with error: %w", err)
}
}
// Stops the streaming
s.cancel()
// mark the value as done
if stopFailed {
return s.initErr
}
s.initErr = fmt.Errorf("service is stopped")
return nil
}
// Old value indicates an update -- otherwise a create
@ -279,7 +289,7 @@ func (s *server) Create(ctx context.Context, req *CreateRequest) (*CreateRespons
ctx, span := s.tracer.Start(ctx, "storage_server.Create")
defer span.End()
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
@ -349,7 +359,7 @@ func (s *server) Update(ctx context.Context, req *UpdateRequest) (*UpdateRespons
ctx, span := s.tracer.Start(ctx, "storage_server.Update")
defer span.End()
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
@ -394,7 +404,7 @@ func (s *server) Delete(ctx context.Context, req *DeleteRequest) (*DeleteRespons
ctx, span := s.tracer.Start(ctx, "storage_server.Delete")
defer span.End()
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
@ -455,7 +465,7 @@ func (s *server) Delete(ctx context.Context, req *DeleteRequest) (*DeleteRespons
}
func (s *server) Read(ctx context.Context, req *ReadRequest) (*ReadResponse, error) {
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
@ -479,7 +489,7 @@ func (s *server) Read(ctx context.Context, req *ReadRequest) (*ReadResponse, err
}
func (s *server) List(ctx context.Context, req *ListRequest) (*ListResponse, error) {
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
@ -508,12 +518,12 @@ func (s *server) initWatcher() error {
}
func (s *server) Watch(req *WatchRequest, srv ResourceStore_WatchServer) error {
if err := s.Init(); err != nil {
ctx := srv.Context()
if err := s.Init(ctx); err != nil {
return err
}
ctx := srv.Context()
// Start listening -- this will buffer any changes that happen while we backfill
stream, err := s.broadcaster.Subscribe(ctx)
if err != nil {
@ -565,7 +575,7 @@ func (s *server) Watch(req *WatchRequest, srv ResourceStore_WatchServer) error {
// History implements ResourceServer.
func (s *server) History(ctx context.Context, req *HistoryRequest) (*HistoryResponse, error) {
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
return s.index.History(ctx, req)
@ -573,7 +583,7 @@ func (s *server) History(ctx context.Context, req *HistoryRequest) (*HistoryResp
// Origin implements ResourceServer.
func (s *server) Origin(ctx context.Context, req *OriginRequest) (*OriginResponse, error) {
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
return s.index.Origin(ctx, req)
@ -581,7 +591,7 @@ func (s *server) Origin(ctx context.Context, req *OriginRequest) (*OriginRespons
// IsHealthy implements ResourceServer.
func (s *server) IsHealthy(ctx context.Context, req *HealthCheckRequest) (*HealthCheckResponse, error) {
if err := s.Init(); err != nil {
if err := s.Init(ctx); err != nil {
return nil, err
}
return s.diagnostics.IsHealthy(ctx, req)

View File

@ -5,130 +5,114 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"text/template"
"sync"
"time"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/protobuf/proto"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/storage/unified/resource"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/dbutil"
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
)
const trace_prefix = "sql.resource."
type backendOptions struct {
DB db.ResourceDBInterface
Tracer trace.Tracer
type Backend interface {
resource.StorageBackend
resource.DiagnosticsServer
resource.LifecycleHooks
}
func NewBackendStore(opts backendOptions) (*backend, error) {
ctx, cancel := context.WithCancel(context.Background())
type BackendOptions struct {
DBProvider db.DBProvider
Tracer trace.Tracer
}
func NewBackend(opts BackendOptions) (Backend, error) {
if opts.DBProvider == nil {
return nil, errors.New("no db provider")
}
if opts.Tracer == nil {
opts.Tracer = noop.NewTracerProvider().Tracer("sql-backend")
}
ctx, cancel := context.WithCancel(context.Background())
return &backend{
db: opts.DB,
log: log.New("sql-resource-server"),
ctx: ctx,
cancel: cancel,
tracer: opts.Tracer,
done: ctx.Done(),
cancel: cancel,
log: log.New("sql-resource-server"),
tracer: opts.Tracer,
dbProvider: opts.DBProvider,
}, nil
}
type backend struct {
log log.Logger
db db.ResourceDBInterface // needed to keep xorm engine in scope
sess *session.SessionDB
dialect migrator.Dialect
ctx context.Context // TODO: remove
cancel context.CancelFunc
tracer trace.Tracer
// server lifecycle
done <-chan struct{}
cancel context.CancelFunc
initOnce sync.Once
initErr error
// o11y
log log.Logger
tracer trace.Tracer
// database
dbProvider db.DBProvider
db db.DB
dialect sqltemplate.Dialect
// watch streaming
//stream chan *resource.WatchEvent
sqlDB db.DB
sqlDialect sqltemplate.Dialect
}
func (b *backend) Init() error {
if b.sess != nil {
return nil
}
func (b *backend) Init(ctx context.Context) error {
b.initOnce.Do(func() {
b.initErr = b.initLocked(ctx)
})
return b.initErr
}
if b.db == nil {
return errors.New("missing db")
}
err := b.db.Init()
func (b *backend) initLocked(ctx context.Context) error {
db, err := b.dbProvider.Init(ctx)
if err != nil {
return err
return fmt.Errorf("initialize resource DB: %w", err)
}
b.db = db
sqlDB, err := b.db.GetDB()
if err != nil {
return err
}
b.sqlDB = sqlDB
driverName := sqlDB.DriverName()
driverName = strings.TrimSuffix(driverName, "WithHooks")
switch driverName {
case db.DriverMySQL:
b.sqlDialect = sqltemplate.MySQL
case db.DriverPostgres:
b.sqlDialect = sqltemplate.PostgreSQL
case db.DriverSQLite, db.DriverSQLite3:
b.sqlDialect = sqltemplate.SQLite
default:
driverName := db.DriverName()
b.dialect = sqltemplate.DialectForDriver(driverName)
if b.dialect == nil {
return fmt.Errorf("no dialect for driver %q", driverName)
}
sess, err := b.db.GetSession()
if err != nil {
return err
}
engine, err := b.db.GetEngine()
if err != nil {
return err
}
b.sess = sess
b.dialect = migrator.NewDialect(engine.DriverName())
return nil
return b.db.PingContext(ctx)
}
func (b *backend) IsHealthy(ctx context.Context, r *resource.HealthCheckRequest) (*resource.HealthCheckResponse, error) {
// ctxLogger := s.log.FromContext(log.WithContextualAttributes(ctx, []any{"method", "isHealthy"}))
if err := b.sqlDB.PingContext(ctx); err != nil {
if err := b.db.PingContext(ctx); err != nil {
return nil, err
}
// TODO: check the status of the watcher implementation as well
return &resource.HealthCheckResponse{Status: resource.HealthCheckResponse_SERVING}, nil
}
func (b *backend) Stop() {
func (b *backend) Stop(_ context.Context) error {
b.cancel()
return nil
}
func (b *backend) WriteEvent(ctx context.Context, event resource.WriteEvent) (int64, error) {
_, span := b.tracer.Start(ctx, trace_prefix+"WriteEvent")
defer span.End()
// TODO: validate key ?
if err := b.Init(); err != nil {
return 0, err
}
switch event.Type {
case resource.WatchEvent_ADDED:
return b.create(ctx, event)
@ -146,12 +130,12 @@ func (b *backend) create(ctx context.Context, event resource.WriteEvent) (int64,
defer span.End()
var newVersion int64
guid := uuid.New().String()
err := b.sqlDB.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
err := b.db.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
// TODO: Set the Labels
// 1. Insert into resource
if _, err := exec(ctx, tx, sqlResourceInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err := dbutil.Exec(ctx, tx, sqlResourceInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
}); err != nil {
@ -159,8 +143,8 @@ func (b *backend) create(ctx context.Context, event resource.WriteEvent) (int64,
}
// 2. Insert into resource history
if _, err := exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err := dbutil.Exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
}); err != nil {
@ -169,28 +153,30 @@ func (b *backend) create(ctx context.Context, event resource.WriteEvent) (int64,
// 3. TODO: Rebuild the whole folder tree structure if we're creating a folder
// 4. Atomically increpement resource version for this kind
rv, err := resourceVersionAtomicInc(ctx, tx, b.sqlDialect, event.Key)
// 4. Atomically increment resource version for this kind
rv, err := resourceVersionAtomicInc(ctx, tx, b.dialect, event.Key)
if err != nil {
return err
return fmt.Errorf("increment resource version: %w", err)
}
newVersion = rv
// 5. Update the RV in both resource and resource_history
if _, err = exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err = dbutil.Exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.dialect),
GUID: guid,
ResourceVersion: newVersion,
ResourceVersion: rv,
}); err != nil {
return fmt.Errorf("update history rv: %w", err)
return fmt.Errorf("update resource_history rv: %w", err)
}
if _, err = exec(ctx, tx, sqlResourceUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err = dbutil.Exec(ctx, tx, sqlResourceUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.dialect),
GUID: guid,
ResourceVersion: newVersion,
ResourceVersion: rv,
}); err != nil {
return fmt.Errorf("update resource rv: %w", err)
}
newVersion = rv
return nil
})
@ -202,30 +188,22 @@ func (b *backend) update(ctx context.Context, event resource.WriteEvent) (int64,
defer span.End()
var newVersion int64
guid := uuid.New().String()
err := b.sqlDB.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
err := b.db.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
// TODO: Set the Labels
// 1. Update into resource
res, err := exec(ctx, tx, sqlResourceUpdate, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
// 1. Update resource
_, err := dbutil.Exec(ctx, tx, sqlResourceUpdate, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
})
if err != nil {
return fmt.Errorf("update into resource: %w", err)
}
count, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("update into resource: %w", err)
}
if count == 0 {
return fmt.Errorf("no rows affected")
return fmt.Errorf("initial resource update: %w", err)
}
// 2. Insert into resource history
if _, err := exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err := dbutil.Exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
}); err != nil {
@ -234,28 +212,29 @@ func (b *backend) update(ctx context.Context, event resource.WriteEvent) (int64,
// 3. TODO: Rebuild the whole folder tree structure if we're creating a folder
// 4. Atomically increpement resource version for this kind
rv, err := resourceVersionAtomicInc(ctx, tx, b.sqlDialect, event.Key)
// 4. Atomically increment resource version for this kind
rv, err := resourceVersionAtomicInc(ctx, tx, b.dialect, event.Key)
if err != nil {
return err
return fmt.Errorf("increment resource version: %w", err)
}
newVersion = rv
// 5. Update the RV in both resource and resource_history
if _, err = exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err = dbutil.Exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.dialect),
GUID: guid,
ResourceVersion: newVersion,
ResourceVersion: rv,
}); err != nil {
return fmt.Errorf("update history rv: %w", err)
}
if _, err = exec(ctx, tx, sqlResourceUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err = dbutil.Exec(ctx, tx, sqlResourceUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.dialect),
GUID: guid,
ResourceVersion: newVersion,
ResourceVersion: rv,
}); err != nil {
return fmt.Errorf("update resource rv: %w", err)
}
newVersion = rv
return nil
})
@ -269,29 +248,22 @@ func (b *backend) delete(ctx context.Context, event resource.WriteEvent) (int64,
var newVersion int64
guid := uuid.New().String()
err := b.sqlDB.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
err := b.db.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
// TODO: Set the Labels
// 1. delete from resource
res, err := exec(ctx, tx, sqlResourceDelete, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
_, err := dbutil.Exec(ctx, tx, sqlResourceDelete, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
})
if err != nil {
return fmt.Errorf("delete resource: %w", err)
}
count, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("delete resource: %w", err)
}
if count == 0 {
return fmt.Errorf("no rows affected")
}
// 2. Add event to resource history
if _, err := exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
// 2. Add event to resource history
if _, err := dbutil.Exec(ctx, tx, sqlResourceHistoryInsert, sqlResourceRequest{
SQLTemplate: sqltemplate.New(b.dialect),
WriteEvent: event,
GUID: guid,
}); err != nil {
@ -300,20 +272,22 @@ func (b *backend) delete(ctx context.Context, event resource.WriteEvent) (int64,
// 3. TODO: Rebuild the whole folder tree structure if we're creating a folder
// 4. Atomically increpement resource version for this kind
newVersion, err = resourceVersionAtomicInc(ctx, tx, b.sqlDialect, event.Key)
// 4. Atomically increment resource version for this kind
rv, err := resourceVersionAtomicInc(ctx, tx, b.dialect, event.Key)
if err != nil {
return err
return fmt.Errorf("increment resource version: %w", err)
}
// 5. Update the RV in resource_history
if _, err = exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
if _, err = dbutil.Exec(ctx, tx, sqlResourceHistoryUpdateRV, sqlResourceUpdateRVRequest{
SQLTemplate: sqltemplate.New(b.dialect),
GUID: guid,
ResourceVersion: newVersion,
ResourceVersion: rv,
}); err != nil {
return fmt.Errorf("update history rv: %w", err)
}
newVersion = rv
return nil
})
@ -325,12 +299,9 @@ func (b *backend) Read(ctx context.Context, req *resource.ReadRequest) (*resourc
defer span.End()
// TODO: validate key ?
if err := b.Init(); err != nil {
return nil, err
}
readReq := sqlResourceReadRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
SQLTemplate: sqltemplate.New(b.dialect),
Request: req,
readResponse: new(readResponse),
}
@ -341,7 +312,7 @@ func (b *backend) Read(ctx context.Context, req *resource.ReadRequest) (*resourc
sr = sqlResourceHistoryRead
}
res, err := queryRow(ctx, b.sqlDB, sr, readReq)
res, err := dbutil.QueryRow(ctx, b.db, sr, readReq)
if errors.Is(err, sql.ErrNoRows) {
return nil, resource.ErrNotFound
} else if err != nil {
@ -361,6 +332,8 @@ func (b *backend) PrepareList(ctx context.Context, req *resource.ListRequest) (*
// TODO: think about how to handler VersionMatch. We should be able to use latest for the first page (only).
// TODO: add support for RemainingItemCount
if req.ResourceVersion > 0 || req.NextPageToken != "" {
return b.listAtRevision(ctx, req)
}
@ -370,56 +343,43 @@ func (b *backend) PrepareList(ctx context.Context, req *resource.ListRequest) (*
// listLatest fetches the resources from the resource table.
func (b *backend) listLatest(ctx context.Context, req *resource.ListRequest) (*resource.ListResponse, error) {
out := &resource.ListResponse{
Items: []*resource.ResourceWrapper{}, // TODO: we could pre-allocate the capacity if we estimate the number of items
ResourceVersion: 0,
}
err := b.sqlDB.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
err := b.db.WithTx(ctx, ReadCommittedRO, func(ctx context.Context, tx db.Tx) error {
var err error
out.ResourceVersion, err = fetchLatestRV(ctx, tx, b.sqlDialect, req.Options.Key.Group, req.Options.Key.Resource)
out.ResourceVersion, err = fetchLatestRV(ctx, tx, b.dialect, req.Options.Key.Group, req.Options.Key.Resource)
if err != nil {
return err
}
// Fetch one extra row for Limit
lim := req.Limit
if req.Limit > 0 {
req.Limit++
}
listReq := sqlResourceListRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
Request: req,
SQLTemplate: sqltemplate.New(b.dialect),
Request: new(resource.ListRequest),
Response: new(resource.ResourceWrapper),
}
query, err := sqltemplate.Execute(sqlResourceList, listReq)
if err != nil {
return fmt.Errorf("execute SQL template to list resources: %w", err)
listReq.Request = proto.Clone(req).(*resource.ListRequest)
if req.Limit > 0 {
listReq.Request.Limit++ // fetch one extra row for Limit
}
rows, err := tx.QueryContext(ctx, query, listReq.GetArgs()...)
items, err := dbutil.Query(ctx, tx, sqlResourceList, listReq)
if err != nil {
return fmt.Errorf("list latest resources: %w", err)
}
defer func() { _ = rows.Close() }()
for i := int64(1); rows.Next(); i++ {
if ctx.Err() != nil {
return ctx.Err()
}
if err := rows.Scan(listReq.GetScanDest()...); err != nil {
return fmt.Errorf("scan row #%d: %w", i, err)
}
if lim > 0 && i > lim {
continueToken := &ContinueToken{ResourceVersion: out.ResourceVersion, StartOffset: lim}
out.NextPageToken = continueToken.String()
break
}
out.Items = append(out.Items, &resource.ResourceWrapper{
ResourceVersion: listReq.Response.ResourceVersion,
Value: listReq.Response.Value,
})
if 0 < req.Limit && int(req.Limit) < len(items) {
// remove the additional item we added synthetically above
clear(items[req.Limit:])
items = items[:req.Limit]
out.NextPageToken = ContinueToken{
ResourceVersion: out.ResourceVersion,
StartOffset: req.Limit,
}.String()
}
out.Items = items
return nil
})
@ -442,20 +402,12 @@ func (b *backend) listAtRevision(ctx context.Context, req *resource.ListRequest)
}
out := &resource.ListResponse{
Items: []*resource.ResourceWrapper{}, // TODO: we could pre-allocate the capacity if we estimate the number of items
ResourceVersion: rv,
}
err := b.sqlDB.WithTx(ctx, ReadCommitted, func(ctx context.Context, tx db.Tx) error {
var err error
// Fetch one extra row for Limit
lim := req.Limit
if lim > 0 {
req.Limit++
}
err := b.db.WithTx(ctx, ReadCommittedRO, func(ctx context.Context, tx db.Tx) error {
listReq := sqlResourceHistoryListRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
SQLTemplate: sqltemplate.New(b.dialect),
Request: &historyListRequest{
ResourceVersion: rv,
Limit: req.Limit,
@ -464,33 +416,26 @@ func (b *backend) listAtRevision(ctx context.Context, req *resource.ListRequest)
},
Response: new(resource.ResourceWrapper),
}
query, err := sqltemplate.Execute(sqlResourceHistoryList, listReq)
if err != nil {
return fmt.Errorf("execute SQL template to list resources at revision: %w", err)
if listReq.Request.Limit > 0 {
listReq.Request.Limit++ // fetch one extra row for Limit
}
rows, err := tx.QueryContext(ctx, query, listReq.GetArgs()...)
items, err := dbutil.Query(ctx, tx, sqlResourceHistoryList, listReq)
if err != nil {
return fmt.Errorf("list resources at revision: %w", err)
}
defer func() { _ = rows.Close() }()
for i := int64(1); rows.Next(); i++ {
if ctx.Err() != nil {
return ctx.Err()
}
if err := rows.Scan(listReq.GetScanDest()...); err != nil {
return fmt.Errorf("scan row #%d: %w", i, err)
}
if lim > 0 && i > lim {
continueToken := &ContinueToken{ResourceVersion: out.ResourceVersion, StartOffset: offset + lim}
out.NextPageToken = continueToken.String()
break
}
out.Items = append(out.Items, &resource.ResourceWrapper{
ResourceVersion: listReq.Response.ResourceVersion,
Value: listReq.Response.Value,
})
if 0 < req.Limit && int(req.Limit) < len(items) {
// remove the additional item we added synthetically above
clear(items[req.Limit:])
items = items[:req.Limit]
out.NextPageToken = ContinueToken{
ResourceVersion: out.ResourceVersion,
StartOffset: req.Limit + offset,
}.String()
}
out.Items = items
return nil
})
@ -499,9 +444,6 @@ func (b *backend) listAtRevision(ctx context.Context, req *resource.ListRequest)
}
func (b *backend) WatchWriteEvents(ctx context.Context) (<-chan *resource.WrittenEvent, error) {
if err := b.Init(); err != nil {
return nil, err
}
// Get the latest RV
since, err := b.listLatestRVs(ctx)
if err != nil {
@ -521,7 +463,7 @@ func (b *backend) poller(ctx context.Context, since groupResourceRV, stream chan
for {
select {
case <-b.ctx.Done():
case <-b.done:
return
case <-t.C:
// List the latest RVs
@ -561,14 +503,14 @@ func (b *backend) poller(ctx context.Context, since groupResourceRV, stream chan
func (b *backend) listLatestRVs(ctx context.Context) (groupResourceRV, error) {
since := groupResourceRV{}
reqRVs := sqlResourceVersionListRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
SQLTemplate: sqltemplate.New(b.dialect),
groupResourceVersion: new(groupResourceVersion),
}
query, err := sqltemplate.Execute(sqlResourceVersionList, reqRVs)
if err != nil {
return nil, fmt.Errorf("execute SQL template to get the latest resource version: %w", err)
}
rows, err := b.sqlDB.QueryContext(ctx, query, reqRVs.GetArgs()...)
rows, err := b.db.QueryContext(ctx, query, reqRVs.GetArgs()...)
if err != nil {
return nil, fmt.Errorf("fetching recent resource versions: %w", err)
}
@ -591,7 +533,7 @@ func (b *backend) listLatestRVs(ctx context.Context) (groupResourceRV, error) {
// fetchLatestRV returns the current maximum RV in the resource table
func fetchLatestRV(ctx context.Context, x db.ContextExecer, d sqltemplate.Dialect, group, resource string) (int64, error) {
res, err := queryRow(ctx, x, sqlResourceVersionGet, sqlResourceVersionRequest{
res, err := dbutil.QueryRow(ctx, x, sqlResourceVersionGet, sqlResourceVersionRequest{
SQLTemplate: sqltemplate.New(d),
Group: group,
Resource: resource,
@ -610,7 +552,7 @@ func (b *backend) poll(ctx context.Context, grp string, res string, since int64,
defer span.End()
pollReq := sqlResourceHistoryPollRequest{
SQLTemplate: sqltemplate.New(b.sqlDialect),
SQLTemplate: sqltemplate.New(b.dialect),
Resource: res,
Group: grp,
SinceResourceVersion: since,
@ -620,8 +562,7 @@ func (b *backend) poll(ctx context.Context, grp string, res string, since int64,
if err != nil {
return since, fmt.Errorf("execute SQL template to poll for resource history: %w", err)
}
rows, err := b.sqlDB.QueryContext(ctx, query, pollReq.GetArgs()...)
rows, err := b.db.QueryContext(ctx, query, pollReq.GetArgs()...)
if err != nil {
return since, fmt.Errorf("poll for resource history: %w", err)
}
@ -667,18 +608,17 @@ func (b *backend) poll(ctx context.Context, grp string, res string, since int64,
func resourceVersionAtomicInc(ctx context.Context, x db.ContextExecer, d sqltemplate.Dialect, key *resource.ResourceKey) (newVersion int64, err error) {
// TODO: refactor this code to run in a multi-statement transaction in order to minimise the number of roundtrips.
// 1 Lock the row for update
req := sqlResourceVersionRequest{
rv, err := dbutil.QueryRow(ctx, x, sqlResourceVersionGet, sqlResourceVersionRequest{
SQLTemplate: sqltemplate.New(d),
Group: key.Group,
Resource: key.Resource,
resourceVersion: new(resourceVersion),
}
rv, err := queryRow(ctx, x, sqlResourceVersionGet, req)
})
if errors.Is(err, sql.ErrNoRows) {
// if there wasn't a row associated with the given resource, we create one with
// version 1
if _, err = exec(ctx, x, sqlResourceVersionInsert, sqlResourceVersionRequest{
if _, err = dbutil.Exec(ctx, x, sqlResourceVersionInsert, sqlResourceVersionRequest{
SQLTemplate: sqltemplate.New(d),
Group: key.Group,
Resource: key.Resource,
@ -687,12 +627,14 @@ func resourceVersionAtomicInc(ctx context.Context, x db.ContextExecer, d sqltemp
}
return 1, nil
}
if err != nil {
return 0, fmt.Errorf("increase resource version: %w", err)
return 0, fmt.Errorf("get current resource version: %w", err)
}
nextRV := rv.ResourceVersion + 1
// 2. Increment the resource version
res, err := exec(ctx, x, sqlResourceVersionInc, sqlResourceVersionRequest{
_, err = dbutil.Exec(ctx, x, sqlResourceVersionInc, sqlResourceVersionRequest{
SQLTemplate: sqltemplate.New(d),
Group: key.Group,
Resource: key.Resource,
@ -704,90 +646,6 @@ func resourceVersionAtomicInc(ctx context.Context, x db.ContextExecer, d sqltemp
return 0, fmt.Errorf("increase resource version: %w", err)
}
if count, err := res.RowsAffected(); err != nil || count == 0 {
return 0, fmt.Errorf("increase resource version did not affect any rows: %w", err)
}
// 3. Retun the incremended value
return nextRV, nil
}
// exec uses `req` as input for a non-data returning query generated with
// `tmpl`, and executed in `x`.
func exec(ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.SQLTemplateIface) (sql.Result, error) {
if err := req.Validate(); err != nil {
return nil, fmt.Errorf("exec: invalid request for template %q: %w",
tmpl.Name(), err)
}
rawQuery, err := sqltemplate.Execute(tmpl, req)
if err != nil {
return nil, fmt.Errorf("execute template: %w", err)
}
query := sqltemplate.FormatSQL(rawQuery)
res, err := x.ExecContext(ctx, query, req.GetArgs()...)
if err != nil {
return nil, SQLError{
Err: err,
CallType: "Exec",
TemplateName: tmpl.Name(),
arguments: req.GetArgs(),
Query: query,
RawQuery: rawQuery,
}
}
return res, nil
}
// queryRow uses `req` as input and output for a single-row returning query
// generated with `tmpl`, and executed in `x`.
func queryRow[T any](ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.WithResults[T]) (T, error) {
var zero T
if err := req.Validate(); err != nil {
return zero, fmt.Errorf("query: invalid request for template %q: %w",
tmpl.Name(), err)
}
rawQuery, err := sqltemplate.Execute(tmpl, req)
if err != nil {
return zero, fmt.Errorf("execute template: %w", err)
}
query := sqltemplate.FormatSQL(rawQuery)
row := x.QueryRowContext(ctx, query, req.GetArgs()...)
if err := row.Err(); err != nil {
return zero, SQLError{
Err: err,
CallType: "QueryRow",
TemplateName: tmpl.Name(),
arguments: req.GetArgs(),
ScanDest: req.GetScanDest(),
Query: query,
RawQuery: rawQuery,
}
}
return scanRow(row, req)
}
type scanner interface {
Scan(dest ...any) error
}
// scanRow is used on *sql.Row and *sql.Rows, and is factored out here not to
// improving code reuse, but rather for ease of testing.
func scanRow[T any](sc scanner, req sqltemplate.WithResults[T]) (zero T, err error) {
if err = sc.Scan(req.GetScanDest()...); err != nil {
return zero, fmt.Errorf("row scan: %w", err)
}
res, err := req.Results()
if err != nil {
return zero, fmt.Errorf("row results: %w", err)
}
return res, nil
}

View File

@ -2,292 +2,579 @@ package sql
import (
"context"
"database/sql/driver"
"errors"
"testing"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/setting"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/storage/unified/resource"
"github.com/grafana/grafana/pkg/storage/unified/sql/db/dbimpl"
"github.com/grafana/grafana/pkg/tests/testsuite"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
"github.com/grafana/grafana/pkg/storage/unified/sql/test"
"github.com/grafana/grafana/pkg/util/testutil"
)
func TestMain(m *testing.M) {
testsuite.Run(m)
}
func TestBackendHappyPath(t *testing.T) {
ctx := context.Background()
dbstore := db.InitTestDB(t)
rdb, err := dbimpl.ProvideResourceDB(dbstore, setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagUnifiedStorage), nil)
assert.NoError(t, err)
store, err := NewBackendStore(backendOptions{
DB: rdb,
})
assert.NoError(t, err)
assert.NotNil(t, store)
stream, err := store.WatchWriteEvents(ctx)
assert.NoError(t, err)
t.Run("Add 3 resources", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(1), rv)
rv, err = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(2), rv)
rv, err = writeEvent(ctx, store, "item3", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(3), rv)
})
t.Run("Update item2", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item2", resource.WatchEvent_MODIFIED)
assert.NoError(t, err)
assert.Equal(t, int64(4), rv)
})
t.Run("Delete item1", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item1", resource.WatchEvent_DELETED)
assert.NoError(t, err)
assert.Equal(t, int64(5), rv)
})
t.Run("Read latest item 2", func(t *testing.T) {
resp, err := store.Read(ctx, &resource.ReadRequest{Key: resourceKey("item2")})
assert.NoError(t, err)
assert.Equal(t, int64(4), resp.ResourceVersion)
assert.Equal(t, "item2 MODIFIED", string(resp.Value))
})
t.Run("Read early verion of item2", func(t *testing.T) {
resp, err := store.Read(ctx, &resource.ReadRequest{
Key: resourceKey("item2"),
ResourceVersion: 3, // item2 was created at rv=2 and updated at rv=4
})
assert.NoError(t, err)
assert.Equal(t, int64(2), resp.ResourceVersion)
assert.Equal(t, "item2 ADDED", string(resp.Value))
})
t.Run("PrepareList latest", func(t *testing.T) {
resp, err := store.PrepareList(ctx, &resource.ListRequest{
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, resp.Items, 2)
assert.Equal(t, "item2 MODIFIED", string(resp.Items[0].Value))
assert.Equal(t, "item3 ADDED", string(resp.Items[1].Value))
assert.Equal(t, int64(5), resp.ResourceVersion)
})
t.Run("Watch events", func(t *testing.T) {
event := <-stream
assert.Equal(t, "item1", event.Key.Name)
assert.Equal(t, int64(1), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item2", event.Key.Name)
assert.Equal(t, int64(2), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item3", event.Key.Name)
assert.Equal(t, int64(3), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item2", event.Key.Name)
assert.Equal(t, int64(4), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_MODIFIED, event.Type)
event = <-stream
assert.Equal(t, "item1", event.Key.Name)
assert.Equal(t, int64(5), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_DELETED, event.Type)
})
}
func TestBackendWatchWriteEventsFromLastest(t *testing.T) {
ctx := context.Background()
dbstore := db.InitTestDB(t)
rdb, err := dbimpl.ProvideResourceDB(dbstore, setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagUnifiedStorage), nil)
assert.NoError(t, err)
store, err := NewBackendStore(backendOptions{
DB: rdb,
})
assert.NoError(t, err)
assert.NotNil(t, store)
// Create a few resources before initing the watch
_, err = writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED)
assert.NoError(t, err)
// Start the watch
stream, err := store.WatchWriteEvents(ctx)
assert.NoError(t, err)
// Create one more event
_, err = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, "item2", (<-stream).Key.Name)
}
func TestBackendPrepareList(t *testing.T) {
ctx := context.Background()
dbstore := db.InitTestDB(t)
rdb, err := dbimpl.ProvideResourceDB(dbstore, setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagUnifiedStorage), nil)
assert.NoError(t, err)
store, err := NewBackendStore(backendOptions{
DB: rdb,
})
assert.NoError(t, err)
assert.NotNil(t, store)
// Create a few resources before initing the watch
_, _ = writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED) // rv=1
_, _ = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED) // rv=2 - will be modified at rv=6
_, _ = writeEvent(ctx, store, "item3", resource.WatchEvent_ADDED) // rv=3 - will be deleted at rv=7
_, _ = writeEvent(ctx, store, "item4", resource.WatchEvent_ADDED) // rv=4
_, _ = writeEvent(ctx, store, "item5", resource.WatchEvent_ADDED) // rv=5
_, _ = writeEvent(ctx, store, "item2", resource.WatchEvent_MODIFIED) // rv=6
_, _ = writeEvent(ctx, store, "item3", resource.WatchEvent_DELETED) // rv=7
_, _ = writeEvent(ctx, store, "item6", resource.WatchEvent_ADDED) // rv=8
t.Run("fetch all latest", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 5)
assert.Empty(t, res.NextPageToken)
})
t.Run("list latest first page ", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Limit: 3,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 3)
continueToken, err := GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(8), continueToken.ResourceVersion)
assert.Equal(t, int64(3), continueToken.StartOffset)
})
t.Run("list at revision", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
ResourceVersion: 4,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 4)
assert.Equal(t, "item1 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item2 ADDED", string(res.Items[1].Value))
assert.Equal(t, "item3 ADDED", string(res.Items[2].Value))
assert.Equal(t, "item4 ADDED", string(res.Items[3].Value))
assert.Empty(t, res.NextPageToken)
})
t.Run("fetch first page at revision with limit", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Limit: 3,
ResourceVersion: 7,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 3)
assert.Equal(t, "item1 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item4 ADDED", string(res.Items[1].Value))
assert.Equal(t, "item5 ADDED", string(res.Items[2].Value))
continueToken, err := GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(7), continueToken.ResourceVersion)
assert.Equal(t, int64(3), continueToken.StartOffset)
})
t.Run("fetch second page at revision", func(t *testing.T) {
continueToken := &ContinueToken{
ResourceVersion: 8,
StartOffset: 2,
}
res, err := store.PrepareList(ctx, &resource.ListRequest{
NextPageToken: continueToken.String(),
Limit: 2,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 2)
assert.Equal(t, "item5 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item2 MODIFIED", string(res.Items[1].Value))
continueToken, err = GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(8), continueToken.ResourceVersion)
assert.Equal(t, int64(4), continueToken.StartOffset)
})
}
func writeEvent(ctx context.Context, store *backend, name string, action resource.WatchEvent_Type) (int64, error) {
return store.WriteEvent(ctx, resource.WriteEvent{
Type: action,
Value: []byte(name + " " + resource.WatchEvent_Type_name[int32(action)]),
Key: &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
Name: name,
},
})
}
func resourceKey(name string) *resource.ResourceKey {
return &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
Name: name,
var (
errTest = errors.New("things happened")
resKey = &resource.ResourceKey{
Namespace: "ns",
Group: "gr",
Resource: "rs",
Name: "nm",
}
)
type (
Cols = []string // column names
Rows = [][]driver.Value // row values returned
)
type testBackend struct {
*backend
test.TestDBProvider
}
func (b testBackend) ExecWithResult(expectedSQL string) {
b.SQLMock.ExpectExec(expectedSQL).WillReturnResult(sqlmock.NewResult(0, 0))
}
func (b testBackend) ExecWithErr(expectedSQL string, err error) {
b.SQLMock.ExpectExec(expectedSQL).WillReturnError(err)
}
func (b testBackend) QueryWithResult(expectedSQL string, numCols int, rs Rows) {
rows := b.SQLMock.NewRows(make([]string, numCols))
if len(rs) > 0 {
rows = rows.AddRows(rs...)
}
b.SQLMock.ExpectQuery(expectedSQL).WillReturnRows(rows)
}
func (b testBackend) QueryWithErr(expectedSQL string, err error) {
b.SQLMock.ExpectQuery(expectedSQL).WillReturnError(err)
}
func setupBackendTest(t *testing.T) (testBackend, context.Context) {
t.Helper()
ctx := testutil.NewDefaultTestContext(t)
dbp := test.NewDBProviderMatchWords(t)
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, b)
err = b.Init(ctx)
require.NoError(t, err)
bb, ok := b.(*backend)
require.True(t, ok)
require.NotNil(t, bb)
return testBackend{
backend: bb,
TestDBProvider: dbp,
}, ctx
}
func TestNewBackend(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
dbp := test.NewDBProviderNopSQL(t)
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, b)
})
t.Run("no db provider", func(t *testing.T) {
t.Parallel()
b, err := NewBackend(BackendOptions{})
require.Nil(t, b)
require.Error(t, err)
require.ErrorContains(t, err, "no db provider")
})
}
func TestBackend_Init(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
ctx := testutil.NewDefaultTestContext(t)
dbp := test.NewDBProviderWithPing(t)
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, b)
dbp.SQLMock.ExpectPing().WillReturnError(nil)
err = b.Init(ctx)
require.NoError(t, err)
// if it isn't idempotent, then it will make a second ping and the
// expectation will fail
err = b.Init(ctx)
require.NoError(t, err, "should be idempotent")
err = b.Stop(ctx)
require.NoError(t, err)
})
t.Run("no db provider", func(t *testing.T) {
t.Parallel()
ctx := testutil.NewDefaultTestContext(t)
dbp := test.TestDBProvider{
Err: errTest,
}
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, b)
err = b.Init(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "initialize resource DB")
})
t.Run("no dialect for driver", func(t *testing.T) {
t.Parallel()
ctx := testutil.NewDefaultTestContext(t)
mockDB, _, err := sqlmock.New()
require.NoError(t, err)
dbp := test.TestDBProvider{
DB: dbimpl.NewDB(mockDB, "juancarlo"),
}
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, b)
err = b.Init(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "no dialect for driver")
})
t.Run("database unreachable", func(t *testing.T) {
t.Parallel()
ctx := testutil.NewDefaultTestContext(t)
dbp := test.NewDBProviderWithPing(t)
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, dbp.DB)
dbp.SQLMock.ExpectPing().WillReturnError(errTest)
err = b.Init(ctx)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
}
func TestBackend_IsHealthy(t *testing.T) {
t.Parallel()
ctx := testutil.NewDefaultTestContext(t)
dbp := test.NewDBProviderWithPing(t)
b, err := NewBackend(BackendOptions{DBProvider: dbp})
require.NoError(t, err)
require.NotNil(t, dbp.DB)
dbp.SQLMock.ExpectPing().WillReturnError(nil)
err = b.Init(ctx)
require.NoError(t, err)
dbp.SQLMock.ExpectPing().WillReturnError(nil)
res, err := b.IsHealthy(ctx, nil)
require.NoError(t, err)
require.NotNil(t, res)
dbp.SQLMock.ExpectPing().WillReturnError(errTest)
res, err = b.IsHealthy(ctx, nil)
require.Nil(t, res)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
}
// expectSuccessfulResourceVersionAtomicInc sets up expectations for calling
// resourceVersionAtomicInc, where the returned RV will be 1.
func expectSuccessfulResourceVersionAtomicInc(t *testing.T, b testBackend) {
b.QueryWithResult("select resource_version for update", 0, nil)
b.ExecWithResult("insert resource_version")
}
// expectUnsuccessfulResourceVersionAtomicInc sets up expectations for calling
// resourceVersionAtomicInc, where the returned RV will be 1.
func expectUnsuccessfulResourceVersionAtomicInc(t *testing.T, b testBackend, err error) {
b.QueryWithErr("select resource_version for update", errTest)
}
func TestResourceVersionAtomicInc(t *testing.T) {
t.Parallel()
dialect := sqltemplate.MySQL
t.Run("happy path - insert new row", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
v, err := resourceVersionAtomicInc(ctx, b.DB, dialect, resKey)
require.NoError(t, err)
require.Equal(t, int64(1), v)
})
t.Run("happy path - update existing row", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.QueryWithResult("select resource_version for update", 1, Rows{{2}})
b.ExecWithResult("update resource_version")
v, err := resourceVersionAtomicInc(ctx, b.DB, dialect, resKey)
require.NoError(t, err)
require.Equal(t, int64(3), v)
})
t.Run("error getting current version", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.QueryWithErr("select resource_version for update", errTest)
v, err := resourceVersionAtomicInc(ctx, b.DB, dialect, resKey)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "get current resource version")
})
t.Run("error inserting new row", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.QueryWithResult("select resource_version for update", 0, nil)
b.ExecWithErr("insert resource_version", errTest)
v, err := resourceVersionAtomicInc(ctx, b.DB, dialect, resKey)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "insert into resource_version")
})
t.Run("error updating existing row", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.QueryWithResult("select resource_version for update", 1, Rows{{2}})
b.ExecWithErr("update resource_version", errTest)
v, err := resourceVersionAtomicInc(ctx, b.DB, dialect, resKey)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "increase resource version")
})
}
func TestBackend_create(t *testing.T) {
t.Parallel()
event := resource.WriteEvent{
Type: resource.WatchEvent_ADDED,
Key: resKey,
}
t.Run("happy path", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("insert resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithResult("update resource_history")
b.ExecWithResult("update resource")
b.SQLMock.ExpectCommit()
v, err := b.create(ctx, event)
require.NoError(t, err)
require.Equal(t, int64(1), v)
})
t.Run("error inserting into resource", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithErr("insert resource", errTest)
b.SQLMock.ExpectRollback()
v, err := b.create(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "insert into resource:")
})
t.Run("error inserting into resource_history", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("insert resource")
b.ExecWithErr("insert resource_history", errTest)
b.SQLMock.ExpectRollback()
v, err := b.create(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "insert into resource history:")
})
t.Run("error incrementing resource version", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("insert resource")
b.ExecWithResult("insert resource_history")
expectUnsuccessfulResourceVersionAtomicInc(t, b, errTest)
b.SQLMock.ExpectRollback()
v, err := b.create(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "increment resource version")
})
t.Run("error updating resource_history", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("insert resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b)
b.ExecWithErr("update resource_history", errTest)
b.SQLMock.ExpectRollback()
v, err := b.create(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "update resource_history")
})
t.Run("error updating resource", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("insert resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b)
b.ExecWithResult("update resource_history")
b.ExecWithErr("update resource", errTest)
b.SQLMock.ExpectRollback()
v, err := b.create(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "update resource rv")
})
}
func TestBackend_update(t *testing.T) {
t.Parallel()
event := resource.WriteEvent{
Type: resource.WatchEvent_MODIFIED,
Key: resKey,
}
t.Run("happy path", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("update resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithResult("update resource_history")
b.ExecWithResult("update resource")
b.SQLMock.ExpectCommit()
v, err := b.update(ctx, event)
require.NoError(t, err)
require.Equal(t, int64(1), v)
})
t.Run("error in first update to resource", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithErr("update resource", errTest)
b.SQLMock.ExpectRollback()
v, err := b.update(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "initial resource update")
})
t.Run("error inserting into resource history", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("update resource")
b.ExecWithErr("insert resource_history", errTest)
b.SQLMock.ExpectRollback()
v, err := b.update(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "insert into resource history")
})
t.Run("error incrementing rv", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("update resource")
b.ExecWithResult("insert resource_history")
expectUnsuccessfulResourceVersionAtomicInc(t, b, errTest)
b.SQLMock.ExpectRollback()
v, err := b.update(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "increment resource version")
})
t.Run("error updating history rv", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("update resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithErr("update resource_history", errTest)
b.SQLMock.ExpectRollback()
v, err := b.update(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "update history rv")
})
t.Run("error updating resource rv", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("update resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithResult("update resource_history")
b.ExecWithErr("update resource", errTest)
b.SQLMock.ExpectRollback()
v, err := b.update(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "update resource rv")
})
}
func TestBackend_delete(t *testing.T) {
t.Parallel()
event := resource.WriteEvent{
Type: resource.WatchEvent_DELETED,
Key: resKey,
}
t.Run("happy path", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("delete resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithResult("update resource_history")
b.SQLMock.ExpectCommit()
v, err := b.delete(ctx, event)
require.NoError(t, err)
require.Equal(t, int64(1), v)
})
t.Run("error deleting resource", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithErr("delete resource", errTest)
b.SQLMock.ExpectCommit()
v, err := b.delete(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "delete resource")
})
t.Run("error inserting into resource history", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("delete resource")
b.ExecWithErr("insert resource_history", errTest)
b.SQLMock.ExpectCommit()
v, err := b.delete(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "insert into resource history")
})
t.Run("error incrementing resource version", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("delete resource")
b.ExecWithResult("insert resource_history")
expectUnsuccessfulResourceVersionAtomicInc(t, b, errTest)
b.SQLMock.ExpectCommit()
v, err := b.delete(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "increment resource version")
})
t.Run("error updating resource history", func(t *testing.T) {
t.Parallel()
b, ctx := setupBackendTest(t)
b.SQLMock.ExpectBegin()
b.ExecWithResult("delete resource")
b.ExecWithResult("insert resource_history")
expectSuccessfulResourceVersionAtomicInc(t, b) // returns RV=1
b.ExecWithErr("update resource_history", errTest)
b.SQLMock.ExpectCommit()
v, err := b.delete(ctx, event)
require.Zero(t, v)
require.Error(t, err)
require.ErrorContains(t, err, "update history rv")
})
}

View File

@ -11,7 +11,7 @@ type ContinueToken struct {
ResourceVersion int64 `json:"v"`
}
func (c *ContinueToken) String() string {
func (c ContinueToken) String() string {
b, _ := json.Marshal(c)
return base64.StdEncoding.EncodeToString(b)
}

View File

@ -4,11 +4,16 @@ import (
"context"
"database/sql"
"fmt"
"strings"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
func NewDB(d *sql.DB, driverName string) resourcedb.DB {
func NewDB(d *sql.DB, driverName string) db.DB {
// remove the suffix from the instrumented driver created by the older
// Grafana code
driverName = strings.TrimSuffix(driverName, "WithHooks")
return sqldb{
DB: d,
driverName: driverName,
@ -24,7 +29,7 @@ func (d sqldb) DriverName() string {
return d.driverName
}
func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (resourcedb.Tx, error) {
func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (db.Tx, error) {
t, err := d.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
@ -34,7 +39,7 @@ func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (resourcedb.Tx,
}, nil
}
func (d sqldb) WithTx(ctx context.Context, opts *sql.TxOptions, f resourcedb.TxFunc) error {
func (d sqldb) WithTx(ctx context.Context, opts *sql.TxOptions, f db.TxFunc) error {
t, err := d.BeginTx(ctx, opts)
if err != nil {
return fmt.Errorf("begin tx: %w", err)

View File

@ -7,13 +7,13 @@ import (
"time"
"github.com/go-sql-driver/mysql"
"go.opentelemetry.io/otel/trace"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/store/entity/db"
)
func getEngineMySQL(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
func getEngineMySQL(getter *sectionGetter, _ trace.Tracer) (*xorm.Engine, error) {
config := mysql.NewConfig()
config.User = getter.String("db_user")
config.Passwd = getter.String("db_pass")
@ -29,6 +29,11 @@ func getEngineMySQL(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, erro
config.AllowNativePasswords = true
config.ClientFoundRows = true
// allow executing multiple SQL statements in a single roundtrip, and also
// enable executing the CALL statement to run stored procedures that execute
// multiple SQL statements.
//config.MultiStatements = true
// TODO: do we want to support these?
// config.ServerPubKey = getter.String("db_server_pub_key")
// config.TLSConfig = getter.String("db_tls_config_name")
@ -54,7 +59,7 @@ func getEngineMySQL(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, erro
return engine, nil
}
func getEnginePostgres(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
func getEnginePostgres(getter *sectionGetter, _ trace.Tracer) (*xorm.Engine, error) {
dsnKV := map[string]string{
"user": getter.String("db_user"),
"password": getter.String("db_pass"),

View File

@ -6,20 +6,22 @@ import (
"github.com/stretchr/testify/assert"
)
func newValidMySQLGetter() *sectionGetter {
return newTestSectionGetter(map[string]string{
"db_type": dbTypeMySQL,
"db_host": "/var/run/mysql.socket",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
}
func TestGetEngineMySQLFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "/var/run/mysql.socket",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEngineMySQL(getter, nil)
engine, err := getEngineMySQL(newValidMySQLGetter(), nil)
assert.NotNil(t, engine)
assert.NoError(t, err)
})
@ -28,7 +30,7 @@ func TestGetEngineMySQLFromConfig(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_type": dbTypeMySQL,
"db_host": "/var/run/mysql.socket",
"db_name": string(invalidUTF8ByteSequence),
"db_user": "user",
@ -37,7 +39,17 @@ func TestGetEngineMySQLFromConfig(t *testing.T) {
engine, err := getEngineMySQL(getter, nil)
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
assert.ErrorIs(t, err, errInvalidUTF8Sequence)
})
}
func newValidPostgresGetter() *sectionGetter {
return newTestSectionGetter(map[string]string{
"db_type": dbTypePostgres,
"db_host": "localhost",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
}
@ -46,15 +58,7 @@ func TestGetEnginePostgresFromConfig(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "localhost",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEnginePostgres(getter, nil)
engine, err := getEnginePostgres(newValidPostgresGetter(), nil)
assert.NotNil(t, engine)
assert.NoError(t, err)
})
@ -62,7 +66,7 @@ func TestGetEnginePostgresFromConfig(t *testing.T) {
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_type": dbTypePostgres,
"db_host": string(invalidUTF8ByteSequence),
"db_name": "grafana",
"db_user": "user",
@ -72,13 +76,13 @@ func TestGetEnginePostgresFromConfig(t *testing.T) {
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
assert.ErrorIs(t, err, errInvalidUTF8Sequence)
})
t.Run("invalid hostport", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_type": dbTypePostgres,
"db_host": "1:1:1",
"db_name": "grafana",
"db_user": "user",

View File

@ -9,7 +9,7 @@ import (
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
func newCtx(t *testing.T) context.Context {
@ -71,8 +71,8 @@ func TestDB_BeginTx(t *testing.T) {
func TestDB_WithTx(t *testing.T) {
t.Parallel()
newTxFunc := func(err error) resourcedb.TxFunc {
return func(context.Context, resourcedb.Tx) error {
newTxFunc := func(err error) db.TxFunc {
return func(context.Context, db.Tx) error {
return err
}
}

View File

@ -1,166 +1,116 @@
package dbimpl
import (
"context"
"fmt"
"sync"
"github.com/dlmiddlecote/sqlstats"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/infra/db"
infraDB "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/setting"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db/migrations"
)
var _ resourcedb.ResourceDBInterface = (*ResourceDB)(nil)
const (
dbTypeMySQL = "mysql"
dbTypePostgres = "postgres"
)
func ProvideResourceDB(db db.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer) (*ResourceDB, error) {
return &ResourceDB{
db: db,
cfg: cfg,
features: features,
log: log.New("entity-db"),
tracer: tracer,
}, nil
}
type ResourceDB struct {
once sync.Once
onceErr error
db db.DB
features featuremgmt.FeatureToggles
engine *xorm.Engine
cfg *setting.Cfg
log log.Logger
tracer tracing.Tracer
}
func (db *ResourceDB) Init() error {
db.once.Do(func() {
db.onceErr = db.init()
})
return db.onceErr
}
func (db *ResourceDB) GetEngine() (*xorm.Engine, error) {
if err := db.Init(); err != nil {
return nil, err
func ProvideResourceDB(grafanaDB infraDB.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer trace.Tracer) (db.DBProvider, error) {
p, err := newResourceDBProvider(grafanaDB, cfg, features, tracer)
if err != nil {
return nil, fmt.Errorf("provide Resource DB: %w", err)
}
var once sync.Once
var resourceDB db.DB
return db.engine, db.onceErr
return dbProviderFunc(func(ctx context.Context) (db.DB, error) {
once.Do(func() {
resourceDB, err = p.init(ctx)
})
return resourceDB, err
}), nil
}
func (db *ResourceDB) init() error {
if db.engine != nil {
return nil
}
type dbProviderFunc func(context.Context) (db.DB, error)
var engine *xorm.Engine
var err error
func (f dbProviderFunc) Init(ctx context.Context) (db.DB, error) {
return f(ctx)
}
type resourceDBProvider struct {
engine *xorm.Engine
cfg *setting.Cfg
log log.Logger
migrateFunc func(context.Context, *xorm.Engine, *setting.Cfg) error
registerMetrics bool
logQueries bool
}
func newResourceDBProvider(grafanaDB infraDB.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer trace.Tracer) (p *resourceDBProvider, err error) {
// TODO: This should be renamed resource_api
getter := &sectionGetter{
DynamicSection: db.cfg.SectionWithEnvOverrides("resource_api"),
DynamicSection: cfg.SectionWithEnvOverrides("resource_api"),
}
dbType := getter.Key("db_type").MustString("")
// if explicit connection settings are provided, use them
if dbType != "" {
if dbType == "postgres" {
engine, err = getEnginePostgres(getter, db.tracer)
if err != nil {
return err
}
// FIXME: this config option is cockroachdb-specific, it's not supported by postgres
// FIXME: this only sets this option for the session that we get
// from the pool right now. A *sql.DB is a pool of connections,
// there is no guarantee that the session where this is run will be
// the same where we need to change the type of a column
_, err = engine.Exec("SET SESSION enable_experimental_alter_column_type_general=true")
if err != nil {
db.log.Error("error connecting to postgres", "msg", err.Error())
// FIXME: return nil, err
}
} else if dbType == "mysql" {
engine, err = getEngineMySQL(getter, db.tracer)
if err != nil {
return err
}
if err = engine.Ping(); err != nil {
return err
}
} else {
// TODO: sqlite support
return fmt.Errorf("invalid db type specified: %s", dbType)
}
// register sql stat metrics
if err := prometheus.Register(sqlstats.NewStatsCollector("unified_storage", engine.DB().DB)); err != nil {
db.log.Warn("Failed to register unified storage sql stats collector", "error", err)
}
// configure sql logging
debugSQL := getter.Key("log_queries").MustBool(false)
if !debugSQL {
engine.SetLogger(&xorm.DiscardLogger{})
} else {
// add stack to database calls to be able to see what repository initiated queries. Top 7 items from the stack as they are likely in the xorm library.
// engine.SetLogger(sqlstore.NewXormLogger(log.LvlInfo, log.WithSuffix(log.New("sqlstore.xorm"), log.CallerContextKey, log.StackCaller(log.DefaultCallerDepth))))
engine.ShowSQL(true)
engine.ShowExecTime(true)
}
// otherwise, try to use the grafana db connection
} else {
if db.db == nil {
return fmt.Errorf("no db connection provided")
}
engine = db.db.GetEngine()
p = &resourceDBProvider{
cfg: cfg,
log: log.New("entity-db"),
logQueries: getter.Key("log_queries").MustBool(false),
}
if features.IsEnabledGlobally(featuremgmt.FlagUnifiedStorage) {
p.migrateFunc = migrations.MigrateResourceStore
}
db.engine = engine
switch dbType := getter.Key("db_type").MustString(""); dbType {
case dbTypePostgres:
p.registerMetrics = true
p.engine, err = getEnginePostgres(getter, tracer)
return p, err
if err := migrations.MigrateResourceStore(engine, db.cfg, db.features); err != nil {
db.engine = nil
return fmt.Errorf("run migrations: %w", err)
case dbTypeMySQL:
p.registerMetrics = true
p.engine, err = getEngineMySQL(getter, tracer)
return p, err
case "":
// try to use the grafana db connection
if grafanaDB == nil {
return p, fmt.Errorf("no db connection provided")
}
p.engine = grafanaDB.GetEngine()
return p, nil
default:
// TODO: sqlite support
return p, fmt.Errorf("invalid db type specified: %s", dbType)
}
return nil
}
func (db *ResourceDB) GetSession() (*session.SessionDB, error) {
engine, err := db.GetEngine()
if err != nil {
return nil, err
func (p *resourceDBProvider) init(ctx context.Context) (db.DB, error) {
if p.registerMetrics {
err := prometheus.Register(sqlstats.NewStatsCollector("unified_storage", p.engine.DB().DB))
if err != nil {
p.log.Warn("Failed to register unified storage sql stats collector", "error", err)
}
}
_ = p.logQueries // TODO: configure SQL logging
// TODO: change the migrator to use db.DB instead of xorm
// Skip migrations if feature flag is not enabled
if p.migrateFunc != nil {
err := p.migrateFunc(ctx, p.engine, p.cfg)
if err != nil {
return nil, fmt.Errorf("run migrations: %w", err)
}
}
return session.GetSession(sqlx.NewDb(engine.DB().DB, engine.DriverName())), nil
}
func (db *ResourceDB) GetCfg() *setting.Cfg {
return db.cfg
}
func (db *ResourceDB) GetDB() (resourcedb.DB, error) {
engine, err := db.GetEngine()
if err != nil {
return nil, err
}
ret := NewDB(engine.DB().DB, engine.Dialect().DriverName())
return ret, nil
return NewDB(p.engine.DB().DB, p.engine.Dialect().DriverName()), nil
}

View File

@ -12,9 +12,7 @@ import (
"github.com/grafana/grafana/pkg/setting"
)
var (
ErrInvalidUTF8Sequence = errors.New("invalid UTF-8 sequence")
)
var errInvalidUTF8Sequence = errors.New("invalid UTF-8 sequence")
type sectionGetter struct {
*setting.DynamicSection
@ -28,7 +26,7 @@ func (g *sectionGetter) Err() error {
func (g *sectionGetter) String(key string) string {
v := g.DynamicSection.Key(key).MustString("")
if !utf8.ValidString(v) {
g.err = fmt.Errorf("value for key %q: %w", key, ErrInvalidUTF8Sequence)
g.err = fmt.Errorf("value for key %q: %w", key, errInvalidUTF8Sequence)
return ""
}
@ -47,7 +45,7 @@ func MakeDSN(m map[string]string) (string, error) {
v := m[k]
if !utf8.ValidString(v) {
return "", fmt.Errorf("value for DSN key %q: %w", k,
ErrInvalidUTF8Sequence)
errInvalidUTF8Sequence)
}
if v == "" {
continue

View File

@ -4,8 +4,9 @@ import (
"fmt"
"testing"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/setting"
)
var invalidUTF8ByteSequence = []byte{0xff, 0xfe, 0xfd}
@ -44,7 +45,7 @@ func TestSectionGetter(t *testing.T) {
v = g.String(key)
require.Empty(t, v)
require.Error(t, g.Err())
require.ErrorIs(t, g.Err(), ErrInvalidUTF8Sequence)
require.ErrorIs(t, g.Err(), errInvalidUTF8Sequence)
}
func TestMakeDSN(t *testing.T) {
@ -55,7 +56,7 @@ func TestMakeDSN(t *testing.T) {
})
require.Empty(t, s)
require.Error(t, err)
require.ErrorIs(t, err, ErrInvalidUTF8Sequence)
require.ErrorIs(t, err, errInvalidUTF8Sequence)
s, err = MakeDSN(map[string]string{
"skip": "",

View File

@ -1,18 +1,16 @@
package migrations
import (
"context"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/setting"
)
func MigrateResourceStore(engine *xorm.Engine, cfg *setting.Cfg, features featuremgmt.FeatureToggles) error {
// Skip if feature flag is not enabled
if !features.IsEnabledGlobally(featuremgmt.FlagUnifiedStorage) {
return nil
}
func MigrateResourceStore(_ context.Context, engine *xorm.Engine, cfg *setting.Cfg) error {
// TODO: use the context.Context
mg := migrator.NewScopedMigrator(engine, cfg, "resource")
mg.AddCreateMigration()

View File

@ -3,30 +3,13 @@ package db
import (
"context"
"database/sql"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/setting"
)
const (
DriverPostgres = "postgres"
DriverMySQL = "mysql"
DriverSQLite = "sqlite"
DriverSQLite3 = "sqlite3"
)
// ResourceDBInterface provides access to a database capable of supporting the
// Entity Server.
type ResourceDBInterface interface {
Init() error
GetCfg() *setting.Cfg
GetDB() (DB, error)
// TODO: deprecate.
GetSession() (*session.SessionDB, error)
GetEngine() (*xorm.Engine, error)
// DBProvider provides access to a SQL Database.
type DBProvider interface {
// Init initializes the SQL Database, running migrations if needed. It is
// idempotent and thread-safe.
Init(context.Context) (DB, error)
}
// DB is a thin abstraction on *sql.DB to allow mocking to provide better unit

View File

@ -0,0 +1,209 @@
// Package dbutil provides utilities to perform common database operations and
// appropriate error handling.
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"text/template"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
)
// SQLError is an error returned by the database, which includes additionally
// debugging information about what was sent to the database.
type SQLError struct {
Err error
CallType string // either Query, QueryRow or Exec
TemplateName string
Query string
RawQuery string
ScanDest []any
// potentially regulated information is not exported and only directly
// available for local testing and local debugging purposes, making sure it
// is never marshaled to JSON or any other serialization.
arguments []any
}
func (e SQLError) Unwrap() error {
return e.Err
}
func (e SQLError) Error() string {
return fmt.Sprintf("%s: %s with %d input arguments and %d output "+
"destination arguments: %v; query: %s", e.TemplateName, e.CallType,
len(e.arguments), len(e.ScanDest), e.Err, e.Query)
}
// Debug provides greater detail about the SQL error. It is defined on the same
// struct but on a test file so that the intention that its results should not
// be used in runtime code is very clear. The results could include PII or
// otherwise regulated information, hence this method is only available in
// tests, so that it can be used in local debugging only. Note that the error
// information may still be available through other means, like using the
// "reflect" package, so care must be taken not to ever expose these information
// in production.
func (e SQLError) Debug() string {
scanDestStr := "(none)"
if len(e.ScanDest) > 0 {
format := "[%T" + strings.Repeat(", %T", len(e.ScanDest)-1) + "]"
scanDestStr = fmt.Sprintf(format, e.ScanDest...)
}
return fmt.Sprintf("%s: %s: %v\n\tArguments (%d): %#v\n\tReturn Value "+
"Types (%d): %s\n\tExecuted Query: %s\n\tRaw SQL Template Output: %s",
e.TemplateName, e.CallType, e.Err, len(e.arguments), e.arguments,
len(e.ScanDest), scanDestStr, e.Query, e.RawQuery)
}
// Debug is meant to provide greater debugging detail about certain errors. The
// returned error will either provide more detailed information or be the same
// original error, suitable only for local debugging. The details provided are
// not meant to be logged, since they could include PII or otherwise
// sensitive/confidential information. These information should only be used for
// local debugging with fake or otherwise non-regulated information.
func Debug(err error) error {
var d interface{ Debug() string }
if errors.As(err, &d) {
return errors.New(d.Debug())
}
return err
}
// Exec uses `req` as input for a non-data returning query generated with
// `tmpl`, and executed in `x`.
func Exec(ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.SQLTemplateIface) (sql.Result, error) {
if err := req.Validate(); err != nil {
return nil, fmt.Errorf("Exec: invalid request for template %q: %w",
tmpl.Name(), err)
}
rawQuery, err := sqltemplate.Execute(tmpl, req)
if err != nil {
return nil, fmt.Errorf("execute template: %w", err)
}
query := sqltemplate.FormatSQL(rawQuery)
res, err := x.ExecContext(ctx, query, req.GetArgs()...)
if err != nil {
return nil, SQLError{
Err: err,
CallType: "Exec",
TemplateName: tmpl.Name(),
arguments: req.GetArgs(),
Query: query,
RawQuery: rawQuery,
}
}
return res, nil
}
// Query uses `req` as input for a single-statement, set-returning query
// generated with `tmpl`, and executed in `x`. The `Results` method of `req`
// should return a deep copy since it will be used multiple times to decode each
// value. It returns an error if more than one result set is returned.
func Query[T any](ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.WithResults[T]) ([]T, error) {
if err := req.Validate(); err != nil {
return nil, fmt.Errorf("Query: invalid request for template %q: %w",
tmpl.Name(), err)
}
rawQuery, err := sqltemplate.Execute(tmpl, req)
if err != nil {
return nil, fmt.Errorf("execute template %q: %w", tmpl.Name(), err)
}
query := sqltemplate.FormatSQL(rawQuery)
rows, err := x.QueryContext(ctx, query, req.GetArgs()...)
if err != nil {
return nil, SQLError{
Err: err,
CallType: "Query",
TemplateName: tmpl.Name(),
arguments: req.GetArgs(),
ScanDest: req.GetScanDest(),
Query: query,
RawQuery: rawQuery,
}
}
var ret []T
for rows.Next() {
v, err := scanRow(rows, req)
if err != nil {
return nil, fmt.Errorf("scan value #%d: %w", len(ret)+1, err)
}
ret = append(ret, v)
}
discardedResultSets, err := DiscardRows(rows)
if err != nil {
return nil, fmt.Errorf("closing rows: %w", err)
}
if discardedResultSets > 1 {
return nil, fmt.Errorf("too many result sets: %v", discardedResultSets)
}
return ret, nil
}
// QueryRow uses `req` as input and output for a single-statement, single-row
// returning query generated with `tmpl`, and executed in `x`. It returns
// sql.ErrNoRows if no rows are returned. It also returns an error if more than
// one row or result set is returned.
func QueryRow[T any](ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.WithResults[T]) (T, error) {
var zero T
res, err := Query(ctx, x, tmpl, req)
if err != nil {
return zero, err
}
switch len(res) {
case 0:
return zero, sql.ErrNoRows
case 1:
return res[0], nil
default:
return zero, fmt.Errorf("expecting a single row, got %d", len(res))
}
}
// DiscardRows discards all the ResultSets in the given *sql.Rows and returns
// the final rows error and the number of times NextResultSet was called. This
// is useful to check for errors in queries with multiple SQL statements where
// there is no interesting output, since some drivers may omit an error returned
// by a SQL statement found in a statement that is not the first one. Note that
// not all drivers support multi-statement calls, though.
func DiscardRows(rows *sql.Rows) (int, error) {
discardedResultSets := 1
for ; rows.NextResultSet(); discardedResultSets++ {
}
return discardedResultSets, rows.Err()
}
type scanner interface {
Scan(dest ...any) error
}
// scanRow is used on *sql.Row and *sql.Rows, and is factored out here not to
// improving code reuse, but rather for ease of testing.
func scanRow[T any](sc scanner, req sqltemplate.WithResults[T]) (zero T, err error) {
if err = sc.Scan(req.GetScanDest()...); err != nil {
return zero, fmt.Errorf("row scan: %w", err)
}
res, err := req.Results()
if err != nil {
return zero, fmt.Errorf("row results: %w", err)
}
return res, nil
}

View File

@ -0,0 +1,513 @@
package dbutil
import (
"database/sql"
"errors"
"testing"
"text/template"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
sqltemplateMocks "github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate/mocks"
"github.com/grafana/grafana/pkg/storage/unified/sql/test"
"github.com/grafana/grafana/pkg/util/testutil"
)
var (
validTestTmpl = template.Must(template.New("test").Parse("nothing special"))
invalidTestTmpl = template.New("no definition should fail to exec")
errTest = errors.New("because of reasons")
)
func TestSQLError(t *testing.T) {
t.Parallel()
const hiddenMessage = "obey, consume"
var err error = SQLError{
Err: errTest,
CallType: "Exec",
TemplateName: "some.sql",
Query: "SELECT name FROM movies WHERE quote LIKE ?",
RawQuery: "SELECT name FROM movies WHERE quote LIKE ?",
ScanDest: []any{new(string)},
arguments: []any{hiddenMessage},
}
require.Error(t, err)
require.ErrorIs(t, err, errTest)
require.NotContains(t, err.Error(), hiddenMessage)
err = Debug(err)
require.Error(t, err)
require.Contains(t, err.Error(), hiddenMessage)
err = Debug(errTest)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
}
// expectRows is a testing helper to keep mocks in sync when adding rows to a
// mocked SQL result.
type expectRows[T any] struct {
*sqlmock.Rows
ExpectedResults []T
req *sqltemplateMocks.WithResults[T]
}
func newReturnsRow[T any](dbmock sqlmock.Sqlmock, req *sqltemplateMocks.WithResults[T]) *expectRows[T] {
return &expectRows[T]{
Rows: dbmock.NewRows(nil),
req: req,
}
}
// Add adds a new value that should be returned by the `Query` or `QueryRow`
// operation.
func (r *expectRows[T]) Add(value T, err error) *expectRows[T] {
r.req.EXPECT().GetScanDest().Return(nil).Once()
r.req.EXPECT().Results().Return(value, err).Once()
r.Rows.AddRow()
r.ExpectedResults = append(r.ExpectedResults, value)
return r
}
func TestQuery(t *testing.T) {
t.Parallel()
t.Run("happy path - no rows returned", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil)
req.EXPECT().GetScanDest().Return(nil).Maybe()
rdb.SQLMock.ExpectQuery("").WillReturnRows(rdb.SQLMock.NewRows(nil))
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.NoError(t, err)
require.Zero(t, res)
})
t.Run("happy path - multiple rows returned", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows.Add(1, nil)
rows.Add(2, nil)
rows.Add(3, nil)
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows.Rows)
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.NoError(t, err)
require.NotZero(t, res)
require.Equal(t, rows.ExpectedResults, res)
})
t.Run("invalid request", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(errTest).Once()
// execute and assert
res, err := Query(ctx, rdb.DB, invalidTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "invalid request")
})
t.Run("error executing template", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
// execute and assert
res, err := Query(ctx, rdb.DB, invalidTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "execute template")
})
t.Run("error executing query", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil)
req.EXPECT().GetScanDest().Return(nil).Maybe()
rdb.SQLMock.ExpectQuery("").WillReturnError(errTest)
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorAs(t, err, new(SQLError))
})
t.Run("error decoding row", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows.Add(0, errTest)
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows.Rows)
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "scan value")
})
t.Run("error iterating rows", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows.Rows.AddRow() // we don't expect GetScanDest or Results here
rows.Rows.RowError(0, errTest)
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows.Rows)
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "closing rows")
})
t.Run("too many result sets", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows1 := newReturnsRow(rdb.SQLMock, req)
rows2 := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows1.Add(1, nil)
rows2.Rows.AddRow() // we don't expect GetScanDest or Results here
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows1.Rows, rows2.Rows)
// execute and assert
res, err := Query(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "too many result sets")
})
}
func TestQueryRow(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows.Add(1, nil)
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows.Rows)
// execute and assert
res, err := QueryRow(ctx, rdb.DB, validTestTmpl, req)
require.NoError(t, err)
require.Equal(t, rows.ExpectedResults[0], res)
})
t.Run("no rows returned", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rdb.SQLMock.ExpectQuery("").WillReturnRows(rdb.SQLMock.NewRows(nil))
// execute and assert
res, err := QueryRow(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorIs(t, err, sql.ErrNoRows)
})
t.Run("error executing query", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil)
req.EXPECT().GetScanDest().Return(nil).Maybe()
rdb.SQLMock.ExpectQuery("").WillReturnError(errTest)
// execute and assert
res, err := QueryRow(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorAs(t, err, new(SQLError))
})
t.Run("too many rows returned", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows.Add(1, nil)
rows.Add(2, nil)
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows.Rows)
// execute and assert
res, err := QueryRow(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "expecting a single row")
})
t.Run("too many result sets", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewWithResults[int64](t)
rdb := test.NewDBProviderNopSQL(t)
rows1 := newReturnsRow(rdb.SQLMock, req)
rows2 := newReturnsRow(rdb.SQLMock, req)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rows1.Add(1, nil)
rows2.Rows.AddRow() // we don't expect GetScanDest or Results here
rdb.SQLMock.ExpectQuery("").WillReturnRows(rows1.Rows, rows2.Rows)
// execute and assert
res, err := QueryRow(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "too many result sets")
})
}
// scannerFunc is an adapter for the `scanner` interface.
type scannerFunc func(dest ...any) error
func (f scannerFunc) Scan(dest ...any) error {
return f(dest...)
}
func TestScanRow(t *testing.T) {
t.Parallel()
const value int64 = 1
t.Run("happy path", func(t *testing.T) {
t.Parallel()
// test declarations
req := sqltemplateMocks.NewWithResults[int64](t)
sc := scannerFunc(func(dest ...any) error {
return nil
})
// setup expectations
req.EXPECT().GetScanDest().Return(nil).Once()
req.EXPECT().Results().Return(value, nil).Once()
// execute and assert
res, err := scanRow(sc, req)
require.NoError(t, err)
require.Equal(t, value, res)
})
t.Run("scan error", func(t *testing.T) {
t.Parallel()
// test declarations
req := sqltemplateMocks.NewWithResults[int64](t)
sc := scannerFunc(func(dest ...any) error {
return errTest
})
// setup expectations
req.EXPECT().GetScanDest().Return(nil).Once()
// execute and assert
res, err := scanRow(sc, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
t.Run("results error", func(t *testing.T) {
t.Parallel()
// test declarations
req := sqltemplateMocks.NewWithResults[int64](t)
sc := scannerFunc(func(dest ...any) error {
return nil
})
// setup expectations
req.EXPECT().GetScanDest().Return(nil).Once()
req.EXPECT().Results().Return(0, errTest).Once()
// execute and assert
res, err := scanRow(sc, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
}
func TestExec(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewSQLTemplateIface(t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil).Once()
rdb.SQLMock.ExpectExec("").WillReturnResult(sqlmock.NewResult(0, 0))
// execute and assert
res, err := Exec(ctx, rdb.DB, validTestTmpl, req)
require.NoError(t, err)
require.NotZero(t, res)
})
t.Run("invalid request", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewSQLTemplateIface(t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(errTest).Once()
// execute and assert
res, err := Exec(ctx, rdb.DB, invalidTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "invalid request")
})
t.Run("error executing template", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewSQLTemplateIface(t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
// execute and assert
res, err := Exec(ctx, rdb.DB, invalidTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorContains(t, err, "execute template")
})
t.Run("error executing SQL", func(t *testing.T) {
t.Parallel()
// test declarations
ctx := testutil.NewDefaultTestContext(t)
req := sqltemplateMocks.NewSQLTemplateIface(t)
rdb := test.NewDBProviderNopSQL(t)
// setup expectations
req.EXPECT().Validate().Return(nil).Once()
req.EXPECT().GetArgs().Return(nil)
rdb.SQLMock.ExpectExec("").WillReturnError(errTest)
// execute and assert
res, err := Exec(ctx, rdb.DB, validTestTmpl, req)
require.Zero(t, res)
require.Error(t, err)
require.ErrorAs(t, err, new(SQLError))
})
}

View File

@ -57,33 +57,6 @@ var (
}
)
// SQLError is an error returned by the database, which includes additionally
// debugging information about what was sent to the database.
type SQLError struct {
Err error
CallType string // either Query, QueryRow or Exec
TemplateName string
Query string
RawQuery string
ScanDest []any
// potentially regulated information is not exported and only directly
// available for local testing and local debugging purposes, making sure it
// is never marshaled to JSON or any other serialization.
arguments []any
}
func (e SQLError) Unwrap() error {
return e.Err
}
func (e SQLError) Error() string {
return fmt.Sprintf("%s: %s with %d input arguments and %d output "+
"destination arguments: %v", e.TemplateName, e.CallType,
len(e.arguments), len(e.ScanDest), e.Err)
}
type sqlResourceRequest struct {
*sqltemplate.SQLTemplate
GUID string
@ -149,6 +122,18 @@ func (r sqlResourceListRequest) Validate() error {
return nil // TODO
}
func (r sqlResourceListRequest) Results() (*resource.ResourceWrapper, error) {
// sqlResourceListRequest is a set-returning query. As such, it
// should not return its *Response, since that will be overwritten in the
// next call to `Scan`, so it needs to return a copy of it. Note, though,
// that it is safe to return the same `Response.Value` since `Scan`
// allocates a new slice of bytes each time.
return &resource.ResourceWrapper{
ResourceVersion: r.Response.ResourceVersion,
Value: r.Response.Value,
}, nil
}
type historyListRequest struct {
ResourceVersion, Limit, Offset int64
Options *resource.ListOptions
@ -163,6 +148,18 @@ func (r sqlResourceHistoryListRequest) Validate() error {
return nil // TODO
}
func (r sqlResourceHistoryListRequest) Results() (*resource.ResourceWrapper, error) {
// sqlResourceHistoryListRequest is a set-returning query. As such, it
// should not return its *Response, since that will be overwritten in the
// next call to `Scan`, so it needs to return a copy of it. Note, though,
// that it is safe to return the same `Response.Value` since `Scan`
// allocates a new slice of bytes each time.
return &resource.ResourceWrapper{
ResourceVersion: r.Response.ResourceVersion,
Value: r.Response.Value,
}, nil
}
// update RV
type sqlResourceUpdateRVRequest struct {

View File

@ -2,7 +2,6 @@ package sql
import (
"embed"
"errors"
"testing"
"text/template"
@ -12,23 +11,6 @@ import (
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
)
// debug is meant to provide greater debugging detail about certain errors. The
// returned error will either provide more detailed information or be the same
// original error, suitable only for local debugging. The details provided are
// not meant to be logged, since they could include PII or otherwise
// sensitive/confidential information. These information should only be used for
// local debugging with fake or otherwise non-regulated information.
func debug(err error) error {
var d interface{ Debug() string }
if errors.As(err, &d) {
return errors.New(d.Debug())
}
return err
}
var _ = debug // silence the `unused` linter
//go:embed testdata/*
var testdataFS embed.FS
@ -126,6 +108,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceUpdate: {
{
Name: "single path",
@ -162,6 +145,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceList: {
{
Name: "filter on namespace",
@ -185,6 +169,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceHistoryList: {
{
Name: "single path",
@ -208,6 +193,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceUpdateRV: {
{
Name: "single path",
@ -222,6 +208,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceHistoryRead: {
{
Name: "single path",
@ -241,6 +228,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceHistoryUpdateRV: {
{
Name: "single path",
@ -255,6 +243,7 @@ func TestQueries(t *testing.T) {
},
},
},
sqlResourceHistoryInsert: {
{
Name: "insert into resource_history",
@ -343,7 +332,7 @@ func TestQueries(t *testing.T) {
expectedQuery := sqltemplate.FormatSQL(rawQuery)
for _, d := range ds {
t.Run(d.Name(), func(t *testing.T) {
t.Run(d.DialectName(), func(t *testing.T) {
// not parallel for the same reason
tc.Data.SetDialect(d)

View File

@ -1,8 +1,9 @@
package sql
import (
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/tracing"
"go.opentelemetry.io/otel/trace"
infraDB "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/storage/unified/resource"
@ -10,7 +11,7 @@ import (
)
// Creates a ResourceServer
func ProvideResourceServer(db db.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer) (resource.ResourceServer, error) {
func ProvideResourceServer(db infraDB.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer trace.Tracer) (resource.ResourceServer, error) {
opts := resource.ResourceServerOptions{
Tracer: tracer,
}
@ -19,7 +20,7 @@ func ProvideResourceServer(db db.DB, cfg *setting.Cfg, features featuremgmt.Feat
if err != nil {
return nil, err
}
store, err := NewBackendStore(backendOptions{DB: eDB, Tracer: tracer})
store, err := NewBackend(BackendOptions{DBProvider: eDB, Tracer: tracer})
if err != nil {
return nil, err
}

View File

@ -12,6 +12,21 @@ var (
ErrInvalidRowLockingClause = errors.New("invalid row-locking clause")
)
// DialectForDriver returns a predefined Dialect for the given driver name, or
// nil if no Dialect is known for that driver.
func DialectForDriver(driverName string) Dialect {
switch strings.ToLower(driverName) {
case "mysql":
return MySQL
case "postgres", "pgx":
return PostgreSQL
case "sqlite", "sqlite3":
return SQLite
default:
return nil
}
}
// Dialect should be added to the data types passed to SQL templates to
// provide methods that deal with SQL implementation-specific traits. It can be
// embedded for ease of use, or with a named struct field if any of its methods
@ -21,7 +36,7 @@ type Dialect interface {
// than one DBMS (e.g. "postgres" is common to PostgreSQL and to
// CockroachDB), while we can maintain different Dialects for the same DBMS
// but different versions (e.g. "mysql5" and "mysql8").
Name() string
DialectName() string
// Ident returns the given string quoted in a way that is suitable to be
// used as an identifier. Database names, schema names, table names, column
@ -135,6 +150,6 @@ var (
type name string
func (n name) Name() string {
func (n name) DialectName() string {
return string(n)
}

View File

@ -178,7 +178,7 @@ func TestName_Name(t *testing.T) {
const v = "some dialect name"
n := name(v)
if n.Name() != v {
t.Fatalf("unexpected dialect name %q", n.Name())
if n.DialectName() != v {
t.Fatalf("unexpected dialect name %q", n.DialectName())
}
}

View File

@ -171,6 +171,51 @@ func (_c *SQLTemplateIface_ArgPlaceholder_Call) RunAndReturn(run func(int) strin
return _c
}
// DialectName provides a mock function with given fields:
func (_m *SQLTemplateIface) DialectName() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for DialectName")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// SQLTemplateIface_DialectName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DialectName'
type SQLTemplateIface_DialectName_Call struct {
*mock.Call
}
// DialectName is a helper method to define mock.On call
func (_e *SQLTemplateIface_Expecter) DialectName() *SQLTemplateIface_DialectName_Call {
return &SQLTemplateIface_DialectName_Call{Call: _e.mock.On("DialectName")}
}
func (_c *SQLTemplateIface_DialectName_Call) Run(run func()) *SQLTemplateIface_DialectName_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *SQLTemplateIface_DialectName_Call) Return(_a0 string) *SQLTemplateIface_DialectName_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *SQLTemplateIface_DialectName_Call) RunAndReturn(run func() string) *SQLTemplateIface_DialectName_Call {
_c.Call.Return(run)
return _c
}
// GetArgs provides a mock function with given fields:
func (_m *SQLTemplateIface) GetArgs() []interface{} {
ret := _m.Called()
@ -425,51 +470,6 @@ func (_c *SQLTemplateIface_Into_Call) RunAndReturn(run func(reflect.Value, strin
return _c
}
// Name provides a mock function with given fields:
func (_m *SQLTemplateIface) Name() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Name")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// SQLTemplateIface_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name'
type SQLTemplateIface_Name_Call struct {
*mock.Call
}
// Name is a helper method to define mock.On call
func (_e *SQLTemplateIface_Expecter) Name() *SQLTemplateIface_Name_Call {
return &SQLTemplateIface_Name_Call{Call: _e.mock.On("Name")}
}
func (_c *SQLTemplateIface_Name_Call) Run(run func()) *SQLTemplateIface_Name_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *SQLTemplateIface_Name_Call) Return(_a0 string) *SQLTemplateIface_Name_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *SQLTemplateIface_Name_Call) RunAndReturn(run func() string) *SQLTemplateIface_Name_Call {
_c.Call.Return(run)
return _c
}
// Reset provides a mock function with given fields:
func (_m *SQLTemplateIface) Reset() {
_m.Called()

View File

@ -171,6 +171,51 @@ func (_c *WithResults_ArgPlaceholder_Call[T]) RunAndReturn(run func(int) string)
return _c
}
// DialectName provides a mock function with given fields:
func (_m *WithResults[T]) DialectName() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for DialectName")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// WithResults_DialectName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DialectName'
type WithResults_DialectName_Call[T interface{}] struct {
*mock.Call
}
// DialectName is a helper method to define mock.On call
func (_e *WithResults_Expecter[T]) DialectName() *WithResults_DialectName_Call[T] {
return &WithResults_DialectName_Call[T]{Call: _e.mock.On("DialectName")}
}
func (_c *WithResults_DialectName_Call[T]) Run(run func()) *WithResults_DialectName_Call[T] {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *WithResults_DialectName_Call[T]) Return(_a0 string) *WithResults_DialectName_Call[T] {
_c.Call.Return(_a0)
return _c
}
func (_c *WithResults_DialectName_Call[T]) RunAndReturn(run func() string) *WithResults_DialectName_Call[T] {
_c.Call.Return(run)
return _c
}
// GetArgs provides a mock function with given fields:
func (_m *WithResults[T]) GetArgs() []interface{} {
ret := _m.Called()
@ -425,51 +470,6 @@ func (_c *WithResults_Into_Call[T]) RunAndReturn(run func(reflect.Value, string)
return _c
}
// Name provides a mock function with given fields:
func (_m *WithResults[T]) Name() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Name")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// WithResults_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name'
type WithResults_Name_Call[T interface{}] struct {
*mock.Call
}
// Name is a helper method to define mock.On call
func (_e *WithResults_Expecter[T]) Name() *WithResults_Name_Call[T] {
return &WithResults_Name_Call[T]{Call: _e.mock.On("Name")}
}
func (_c *WithResults_Name_Call[T]) Run(run func()) *WithResults_Name_Call[T] {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *WithResults_Name_Call[T]) Return(_a0 string) *WithResults_Name_Call[T] {
_c.Call.Return(_a0)
return _c
}
func (_c *WithResults_Name_Call[T]) RunAndReturn(run func() string) *WithResults_Name_Call[T] {
_c.Call.Return(run)
return _c
}
// Reset provides a mock function with given fields:
func (_m *WithResults[T]) Reset() {
_m.Called()

View File

@ -0,0 +1,296 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
infraDB "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/storage/unified/resource"
"github.com/grafana/grafana/pkg/storage/unified/sql"
"github.com/grafana/grafana/pkg/storage/unified/sql/db/dbimpl"
"github.com/grafana/grafana/pkg/tests/testsuite"
"github.com/grafana/grafana/pkg/util/testutil"
)
func TestMain(m *testing.M) {
testsuite.Run(m)
}
func newServer(t *testing.T) sql.Backend {
t.Helper()
dbstore := infraDB.InitTestDB(t)
cfg := setting.NewCfg()
features := featuremgmt.WithFeatures(featuremgmt.FlagUnifiedStorage)
tr := noop.NewTracerProvider().Tracer("integrationtests")
eDB, err := dbimpl.ProvideResourceDB(dbstore, cfg, features, tr)
require.NoError(t, err)
require.NotNil(t, eDB)
ret, err := sql.NewBackend(sql.BackendOptions{
DBProvider: eDB,
Tracer: tr,
})
require.NoError(t, err)
require.NotNil(t, ret)
err = ret.Init(testutil.NewDefaultTestContext(t))
require.NoError(t, err)
return ret
}
func TestBackendHappyPath(t *testing.T) {
ctx := testutil.NewDefaultTestContext(t)
store := newServer(t)
stream, err := store.WatchWriteEvents(ctx)
assert.NoError(t, err)
t.Run("Add 3 resources", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(1), rv)
rv, err = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(2), rv)
rv, err = writeEvent(ctx, store, "item3", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, int64(3), rv)
})
t.Run("Update item2", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item2", resource.WatchEvent_MODIFIED)
assert.NoError(t, err)
assert.Equal(t, int64(4), rv)
})
t.Run("Delete item1", func(t *testing.T) {
rv, err := writeEvent(ctx, store, "item1", resource.WatchEvent_DELETED)
assert.NoError(t, err)
assert.Equal(t, int64(5), rv)
})
t.Run("Read latest item 2", func(t *testing.T) {
resp, err := store.Read(ctx, &resource.ReadRequest{Key: resourceKey("item2")})
assert.NoError(t, err)
assert.Equal(t, int64(4), resp.ResourceVersion)
assert.Equal(t, "item2 MODIFIED", string(resp.Value))
})
t.Run("Read early verion of item2", func(t *testing.T) {
resp, err := store.Read(ctx, &resource.ReadRequest{
Key: resourceKey("item2"),
ResourceVersion: 3, // item2 was created at rv=2 and updated at rv=4
})
assert.NoError(t, err)
assert.Equal(t, int64(2), resp.ResourceVersion)
assert.Equal(t, "item2 ADDED", string(resp.Value))
})
t.Run("PrepareList latest", func(t *testing.T) {
resp, err := store.PrepareList(ctx, &resource.ListRequest{
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, resp.Items, 2)
assert.Equal(t, "item2 MODIFIED", string(resp.Items[0].Value))
assert.Equal(t, "item3 ADDED", string(resp.Items[1].Value))
assert.Equal(t, int64(5), resp.ResourceVersion)
})
t.Run("Watch events", func(t *testing.T) {
event := <-stream
assert.Equal(t, "item1", event.Key.Name)
assert.Equal(t, int64(1), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item2", event.Key.Name)
assert.Equal(t, int64(2), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item3", event.Key.Name)
assert.Equal(t, int64(3), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_ADDED, event.Type)
event = <-stream
assert.Equal(t, "item2", event.Key.Name)
assert.Equal(t, int64(4), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_MODIFIED, event.Type)
event = <-stream
assert.Equal(t, "item1", event.Key.Name)
assert.Equal(t, int64(5), event.ResourceVersion)
assert.Equal(t, resource.WatchEvent_DELETED, event.Type)
})
}
func TestBackendWatchWriteEventsFromLastest(t *testing.T) {
ctx := testutil.NewDefaultTestContext(t)
store := newServer(t)
// Create a few resources before initing the watch
_, err := writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED)
assert.NoError(t, err)
// Start the watch
stream, err := store.WatchWriteEvents(ctx)
assert.NoError(t, err)
// Create one more event
_, err = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED)
assert.NoError(t, err)
assert.Equal(t, "item2", (<-stream).Key.Name)
}
func TestBackendPrepareList(t *testing.T) {
ctx := testutil.NewDefaultTestContext(t)
store := newServer(t)
// Create a few resources before initing the watch
_, _ = writeEvent(ctx, store, "item1", resource.WatchEvent_ADDED) // rv=1
_, _ = writeEvent(ctx, store, "item2", resource.WatchEvent_ADDED) // rv=2 - will be modified at rv=6
_, _ = writeEvent(ctx, store, "item3", resource.WatchEvent_ADDED) // rv=3 - will be deleted at rv=7
_, _ = writeEvent(ctx, store, "item4", resource.WatchEvent_ADDED) // rv=4
_, _ = writeEvent(ctx, store, "item5", resource.WatchEvent_ADDED) // rv=5
_, _ = writeEvent(ctx, store, "item2", resource.WatchEvent_MODIFIED) // rv=6
_, _ = writeEvent(ctx, store, "item3", resource.WatchEvent_DELETED) // rv=7
_, _ = writeEvent(ctx, store, "item6", resource.WatchEvent_ADDED) // rv=8
t.Run("fetch all latest", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 5)
assert.Empty(t, res.NextPageToken)
})
t.Run("list latest first page ", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Limit: 3,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 3)
continueToken, err := sql.GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(8), continueToken.ResourceVersion)
assert.Equal(t, int64(3), continueToken.StartOffset)
})
t.Run("list at revision", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
ResourceVersion: 4,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 4)
assert.Equal(t, "item1 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item2 ADDED", string(res.Items[1].Value))
assert.Equal(t, "item3 ADDED", string(res.Items[2].Value))
assert.Equal(t, "item4 ADDED", string(res.Items[3].Value))
assert.Empty(t, res.NextPageToken)
})
t.Run("fetch first page at revision with limit", func(t *testing.T) {
res, err := store.PrepareList(ctx, &resource.ListRequest{
Limit: 3,
ResourceVersion: 7,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 3)
assert.Equal(t, "item1 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item4 ADDED", string(res.Items[1].Value))
assert.Equal(t, "item5 ADDED", string(res.Items[2].Value))
continueToken, err := sql.GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(7), continueToken.ResourceVersion)
assert.Equal(t, int64(3), continueToken.StartOffset)
})
t.Run("fetch second page at revision", func(t *testing.T) {
continueToken := &sql.ContinueToken{
ResourceVersion: 8,
StartOffset: 2,
}
res, err := store.PrepareList(ctx, &resource.ListRequest{
NextPageToken: continueToken.String(),
Limit: 2,
Options: &resource.ListOptions{
Key: &resource.ResourceKey{
Group: "group",
Resource: "resource",
},
},
})
assert.NoError(t, err)
assert.Len(t, res.Items, 2)
assert.Equal(t, "item5 ADDED", string(res.Items[0].Value))
assert.Equal(t, "item2 MODIFIED", string(res.Items[1].Value))
continueToken, err = sql.GetContinueToken(res.NextPageToken)
assert.NoError(t, err)
assert.Equal(t, int64(8), continueToken.ResourceVersion)
assert.Equal(t, int64(4), continueToken.StartOffset)
})
}
func writeEvent(ctx context.Context, store sql.Backend, name string, action resource.WatchEvent_Type) (int64, error) {
return store.WriteEvent(ctx, resource.WriteEvent{
Type: action,
Value: []byte(name + " " + resource.WatchEvent_Type_name[int32(action)]),
Key: &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
Name: name,
},
})
}
func resourceKey(name string) *resource.ResourceKey {
return &resource.ResourceKey{
Namespace: "namespace",
Group: "group",
Resource: "resource",
Name: name,
}
}

View File

@ -0,0 +1,128 @@
package test
import (
"context"
"fmt"
"regexp"
"strings"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db/dbimpl"
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
)
// TestDBProvider is a stub for db.ResourceDBInterface.
type TestDBProvider struct {
Err error
DB db.DB
SQLMock sqlmock.Sqlmock
}
func (d TestDBProvider) Init(context.Context) (db.DB, error) {
return d.DB, d.Err
}
var _ db.DBProvider = TestDBProvider{}
// NewDBProviderNopSQL returns a TestDBProvider with a sqlmock.Sqlmock that
// doesn't validates SQL. This is only meant to be used to test wrapping
// utilities where the actual SQL is not relevant to the unit tests, but rather
// how the possible derived error conditions handled.
func NewDBProviderNopSQL(t *testing.T) TestDBProvider {
t.Helper()
mockDB, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherFunc(
func(string, string) error { return nil },
)),
)
require.NoError(t, err)
return TestDBProvider{
DB: dbimpl.NewDB(mockDB, "mysql"),
SQLMock: mock,
}
}
// NewDBProviderWithPing requires that database pings have a matching
// expectation, which are ignored by default. The SQL matching is the sqlmock
// default.
func NewDBProviderWithPing(t *testing.T) TestDBProvider {
t.Helper()
mockDB, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
return TestDBProvider{
DB: dbimpl.NewDB(mockDB, "mysql"),
SQLMock: mock,
}
}
// NewDBProviderMatchWords returns a TestDBProvider with a sqlmock.Sqlmock that
// will match SQL by splitting the expected SQL string into words, and then try
// to find all of them in the actual SQL, in the given order, case
// insensitively. Prepend a word with a `!` to say that word should not be
// found.
func NewDBProviderMatchWords(t *testing.T) TestDBProvider {
t.Helper()
mockDB, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(
sqlmock.QueryMatcherFunc(wordsMatcherFunc),
),
)
require.NoError(t, err)
return TestDBProvider{
DB: dbimpl.NewDB(mockDB, "mysql"),
SQLMock: mock,
}
}
func wordsMatcherFunc(expectedSQL, actualSQL string) error {
actualSQL = strings.ToLower(sqltemplate.FormatSQL(actualSQL))
expectedSQL = strings.ToLower(expectedSQL)
var offset int
for _, vv := range matchWorsRE.FindAllStringSubmatch(expectedSQL, -1) {
v := vv[1]
var shouldNotMatch bool
if v != "" && v[0] == '!' {
v = v[1:]
shouldNotMatch = true
}
if v == "" {
return fmt.Errorf("invalid expected word %q in %q", v,
expectedSQL)
}
reWord, err := regexp.Compile(`\b` + regexp.QuoteMeta(v) + `\b`)
if err != nil {
return fmt.Errorf("compile word %q from expected SQL: %s", v,
expectedSQL)
}
if shouldNotMatch {
if reWord.MatchString(actualSQL[offset:]) {
return fmt.Errorf("actual SQL fragent should not cont"+
"ain %q but it does\n\tFragment: %s\n\tFull SQL: %s",
v, actualSQL[offset:], actualSQL)
}
} else {
loc := reWord.FindStringIndex(actualSQL[offset:])
if len(loc) == 0 {
return fmt.Errorf("actual SQL fragment should contain "+
"%q but it doesn't\n\tFragment: %s\n\tFull SQL: %s",
v, actualSQL[offset:], actualSQL)
}
offset = loc[1] // advance the offset
}
}
return nil
}
var matchWorsRE = regexp.MustCompile(`(?:\W|\A)(!?\w+)\b`)