From 8143610024ef01729a850659072bab83fe5694c1 Mon Sep 17 00:00:00 2001 From: bergquist Date: Tue, 5 Jun 2018 21:13:53 +0200 Subject: [PATCH] bus: support multiple dispatch in one transaction this makes it possible to run multiple DispatchCtx in one transaction. The TransactionManager will start/end the transaction and pass the dbsession in the context.Context variable --- pkg/bus/bus.go | 50 +++++++++++++++++++++++++ pkg/services/alerting/notifiers/base.go | 1 + pkg/services/sqlstore/shared.go | 28 +++++++++++++- pkg/services/sqlstore/sqlstore.go | 49 +++++++++++++++++++++++- pkg/services/sqlstore/stats.go | 18 +++++++++ 5 files changed, 144 insertions(+), 2 deletions(-) diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 32a591b6672..98279678777 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -12,21 +12,51 @@ type Msg interface{} var ErrHandlerNotFound = errors.New("handler not found") +type TransactionManager interface { + Begin(ctx context.Context) (context.Context, error) + End(ctx context.Context, err error) error +} + type Bus interface { Dispatch(msg Msg) error DispatchCtx(ctx context.Context, msg Msg) error Publish(msg Msg) error + // InTransaction starts a transaction and store it in the context. + // The caller can then pass a function with multiple DispatchCtx calls that + // all will be executed in the same transaction. InTransaction will rollback if the + // callback returns an error.s + InTransaction(ctx context.Context, fn func(ctx context.Context) error) error + AddHandler(handler HandlerFunc) AddCtxHandler(handler HandlerFunc) AddEventListener(handler HandlerFunc) AddWildcardListener(handler HandlerFunc) + + // SetTransactionManager allows the user to replace the internal + // noop TransactionManager that is responsible for manageing + // transactions in `InTransaction` + SetTransactionManager(tm TransactionManager) +} + +func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error { + ctxWithTran, err := b.transactionManager.Begin(ctx) + if err != nil { + return err + } + + err = fn(ctxWithTran) + b.transactionManager.End(ctxWithTran, err) + + return err } type InProcBus struct { handlers map[string]HandlerFunc listeners map[string][]HandlerFunc wildcardListeners []HandlerFunc + + transactionManager TransactionManager } // temp stuff, not sure how to handle bus instance, and init yet @@ -37,6 +67,9 @@ func New() Bus { bus.handlers = make(map[string]HandlerFunc) bus.listeners = make(map[string][]HandlerFunc) bus.wildcardListeners = make([]HandlerFunc, 0) + + bus.transactionManager = &NoopTransactionManager{} + return bus } @@ -45,6 +78,14 @@ func GetBus() Bus { return globalBus } +func SetTransactionManager(tm TransactionManager) { + globalBus.SetTransactionManager(tm) +} + +func (b *InProcBus) SetTransactionManager(tm TransactionManager) { + b.transactionManager = tm +} + func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error { var msgName = reflect.TypeOf(msg).Elem().Name() @@ -167,6 +208,15 @@ func Publish(msg Msg) error { return globalBus.Publish(msg) } +func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error { + return globalBus.InTransaction(ctx, fn) +} + func ClearBusHandlers() { globalBus = New() } + +type NoopTransactionManager struct{} + +func (*NoopTransactionManager) Begin(ctx context.Context) (context.Context, error) { return ctx, nil } +func (*NoopTransactionManager) End(ctx context.Context, err error) error { return err } diff --git a/pkg/services/alerting/notifiers/base.go b/pkg/services/alerting/notifiers/base.go index 51676efdfd5..868db3aec79 100644 --- a/pkg/services/alerting/notifiers/base.go +++ b/pkg/services/alerting/notifiers/base.go @@ -3,6 +3,7 @@ package notifiers import ( "github.com/grafana/grafana/pkg/components/simplejson" m "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/alerting" ) diff --git a/pkg/services/sqlstore/shared.go b/pkg/services/sqlstore/shared.go index 9a24a513aad..3ccb92f010f 100644 --- a/pkg/services/sqlstore/shared.go +++ b/pkg/services/sqlstore/shared.go @@ -1,6 +1,7 @@ package sqlstore import ( + "context" "reflect" "time" @@ -29,10 +30,35 @@ func inTransaction(callback dbTransactionFunc) error { return inTransactionWithRetry(callback, 0) } +func startSession(ctx context.Context) *DBSession { + value := ctx.Value(ContextSessionName) + var sess *xorm.Session + sess, ok := value.(*xorm.Session) + + if !ok { + return newSession() + } + + old := newSession() + old.Session = sess + + return old +} + +func withDbSession(ctx context.Context, callback dbTransactionFunc) error { + sess := startSession(ctx) + + return callback(sess) +} + func inTransactionWithRetry(callback dbTransactionFunc, retry int) error { + return inTransactionWithRetryCtx(context.Background(), callback, retry) +} + +func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error { var err error - sess := newSession() + sess := startSession(ctx) defer sess.Close() if err = sess.Begin(); err != nil { diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index ed82829665f..bfe462f4d91 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -1,6 +1,8 @@ package sqlstore import ( + "context" + "errors" "fmt" "net/url" "os" @@ -35,6 +37,8 @@ var ( sqlog log.Logger = log.New("sqlstore") ) +const ContextSessionName = "db-session" + func init() { registry.Register(®istry.Descriptor{ Name: "SqlStore", @@ -45,6 +49,7 @@ func init() { type SqlStore struct { Cfg *setting.Cfg `inject:""` + Bus bus.Bus `inject:""` dbCfg DatabaseConfig engine *xorm.Engine @@ -77,6 +82,10 @@ func (ss *SqlStore) Init() error { // Init repo instances annotations.SetRepository(&SqlAnnotationRepo{}) + ss.Bus.SetTransactionManager(&SQLTransactionManager{ + engine: ss.engine, + }) + // ensure admin user if ss.skipEnsureAdmin { return nil @@ -85,10 +94,47 @@ func (ss *SqlStore) Init() error { return ss.ensureAdminUser() } +type SQLTransactionManager struct { + engine *xorm.Engine +} + +func (stm *SQLTransactionManager) Begin(ctx context.Context) (context.Context, error) { + sess := stm.engine.NewSession() + err := sess.Begin() + if err != nil { + return ctx, err + } + + withValue := context.WithValue(ctx, ContextSessionName, sess) + + return withValue, nil +} + +func (stm *SQLTransactionManager) End(ctx context.Context, err error) error { + value := ctx.Value(ContextSessionName) + sess, ok := value.(*xorm.Session) + if !ok { + return errors.New("context is missing transaction") + } + + if err != nil { + sess.Rollback() + return err + } + + defer sess.Close() + + return sess.Commit() +} + func (ss *SqlStore) ensureAdminUser() error { systemUserCountQuery := m.GetSystemUserCountStatsQuery{} - if err := bus.Dispatch(&systemUserCountQuery); err != nil { + err := bus.InTransaction(context.Background(), func(ctx context.Context) error { + return bus.DispatchCtx(ctx, &systemUserCountQuery) + }) + + if err != nil { return fmt.Errorf("Could not determine if admin user exists: %v", err) } @@ -240,6 +286,7 @@ func (ss *SqlStore) readConfig() { func InitTestDB(t *testing.T) *SqlStore { sqlstore := &SqlStore{} sqlstore.skipEnsureAdmin = true + sqlstore.Bus = bus.New() dbType := migrator.SQLITE diff --git a/pkg/services/sqlstore/stats.go b/pkg/services/sqlstore/stats.go index 3e3e83c4014..af4482d9e25 100644 --- a/pkg/services/sqlstore/stats.go +++ b/pkg/services/sqlstore/stats.go @@ -1,6 +1,7 @@ package sqlstore import ( + "context" "time" "github.com/grafana/grafana/pkg/bus" @@ -13,6 +14,7 @@ func init() { bus.AddHandler("sql", GetDataSourceAccessStats) bus.AddHandler("sql", GetAdminStats) bus.AddHandler("sql", GetSystemUserCountStats) + bus.AddCtxHandler("sql", GetSystemUserCountStatsCtx) } var activeUserTimeLimit = time.Hour * 24 * 30 @@ -133,6 +135,22 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error { return err } +func GetSystemUserCountStatsCtx(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error { + return withDbSession(ctx, func(sess *DBSession) error { + + var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user") + var stats m.SystemUserCountStats + _, err := sess.SQL(rawSql).Get(&stats) + if err != nil { + return err + } + + query.Result = &stats + + return err + }) +} + func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error { var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user") var stats m.SystemUserCountStats