bus: support multiple dispatch in one transaction

this makes it possible to run multiple DispatchCtx
in one transaction. The TransactionManager will
start/end the transaction and pass the dbsession
in the context.Context variable
This commit is contained in:
bergquist 2018-06-05 21:13:53 +02:00
parent e33d18701d
commit 8143610024
5 changed files with 144 additions and 2 deletions

View File

@ -12,21 +12,51 @@ type Msg interface{}
var ErrHandlerNotFound = errors.New("handler not found") var ErrHandlerNotFound = errors.New("handler not found")
type TransactionManager interface {
Begin(ctx context.Context) (context.Context, error)
End(ctx context.Context, err error) error
}
type Bus interface { type Bus interface {
Dispatch(msg Msg) error Dispatch(msg Msg) error
DispatchCtx(ctx context.Context, msg Msg) error DispatchCtx(ctx context.Context, msg Msg) error
Publish(msg Msg) error Publish(msg Msg) error
// InTransaction starts a transaction and store it in the context.
// The caller can then pass a function with multiple DispatchCtx calls that
// all will be executed in the same transaction. InTransaction will rollback if the
// callback returns an error.s
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
AddHandler(handler HandlerFunc) AddHandler(handler HandlerFunc)
AddCtxHandler(handler HandlerFunc) AddCtxHandler(handler HandlerFunc)
AddEventListener(handler HandlerFunc) AddEventListener(handler HandlerFunc)
AddWildcardListener(handler HandlerFunc) AddWildcardListener(handler HandlerFunc)
// SetTransactionManager allows the user to replace the internal
// noop TransactionManager that is responsible for manageing
// transactions in `InTransaction`
SetTransactionManager(tm TransactionManager)
}
func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
ctxWithTran, err := b.transactionManager.Begin(ctx)
if err != nil {
return err
}
err = fn(ctxWithTran)
b.transactionManager.End(ctxWithTran, err)
return err
} }
type InProcBus struct { type InProcBus struct {
handlers map[string]HandlerFunc handlers map[string]HandlerFunc
listeners map[string][]HandlerFunc listeners map[string][]HandlerFunc
wildcardListeners []HandlerFunc wildcardListeners []HandlerFunc
transactionManager TransactionManager
} }
// temp stuff, not sure how to handle bus instance, and init yet // temp stuff, not sure how to handle bus instance, and init yet
@ -37,6 +67,9 @@ func New() Bus {
bus.handlers = make(map[string]HandlerFunc) bus.handlers = make(map[string]HandlerFunc)
bus.listeners = make(map[string][]HandlerFunc) bus.listeners = make(map[string][]HandlerFunc)
bus.wildcardListeners = make([]HandlerFunc, 0) bus.wildcardListeners = make([]HandlerFunc, 0)
bus.transactionManager = &NoopTransactionManager{}
return bus return bus
} }
@ -45,6 +78,14 @@ func GetBus() Bus {
return globalBus return globalBus
} }
func SetTransactionManager(tm TransactionManager) {
globalBus.SetTransactionManager(tm)
}
func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
b.transactionManager = tm
}
func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error { func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
var msgName = reflect.TypeOf(msg).Elem().Name() var msgName = reflect.TypeOf(msg).Elem().Name()
@ -167,6 +208,15 @@ func Publish(msg Msg) error {
return globalBus.Publish(msg) return globalBus.Publish(msg)
} }
func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
return globalBus.InTransaction(ctx, fn)
}
func ClearBusHandlers() { func ClearBusHandlers() {
globalBus = New() globalBus = New()
} }
type NoopTransactionManager struct{}
func (*NoopTransactionManager) Begin(ctx context.Context) (context.Context, error) { return ctx, nil }
func (*NoopTransactionManager) End(ctx context.Context, err error) error { return err }

View File

@ -3,6 +3,7 @@ package notifiers
import ( import (
"github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/components/simplejson"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/alerting" "github.com/grafana/grafana/pkg/services/alerting"
) )

View File

@ -1,6 +1,7 @@
package sqlstore package sqlstore
import ( import (
"context"
"reflect" "reflect"
"time" "time"
@ -29,10 +30,35 @@ func inTransaction(callback dbTransactionFunc) error {
return inTransactionWithRetry(callback, 0) return inTransactionWithRetry(callback, 0)
} }
func startSession(ctx context.Context) *DBSession {
value := ctx.Value(ContextSessionName)
var sess *xorm.Session
sess, ok := value.(*xorm.Session)
if !ok {
return newSession()
}
old := newSession()
old.Session = sess
return old
}
func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
sess := startSession(ctx)
return callback(sess)
}
func inTransactionWithRetry(callback dbTransactionFunc, retry int) error { func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
return inTransactionWithRetryCtx(context.Background(), callback, retry)
}
func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error {
var err error var err error
sess := newSession() sess := startSession(ctx)
defer sess.Close() defer sess.Close()
if err = sess.Begin(); err != nil { if err = sess.Begin(); err != nil {

View File

@ -1,6 +1,8 @@
package sqlstore package sqlstore
import ( import (
"context"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
@ -35,6 +37,8 @@ var (
sqlog log.Logger = log.New("sqlstore") sqlog log.Logger = log.New("sqlstore")
) )
const ContextSessionName = "db-session"
func init() { func init() {
registry.Register(&registry.Descriptor{ registry.Register(&registry.Descriptor{
Name: "SqlStore", Name: "SqlStore",
@ -45,6 +49,7 @@ func init() {
type SqlStore struct { type SqlStore struct {
Cfg *setting.Cfg `inject:""` Cfg *setting.Cfg `inject:""`
Bus bus.Bus `inject:""`
dbCfg DatabaseConfig dbCfg DatabaseConfig
engine *xorm.Engine engine *xorm.Engine
@ -77,6 +82,10 @@ func (ss *SqlStore) Init() error {
// Init repo instances // Init repo instances
annotations.SetRepository(&SqlAnnotationRepo{}) annotations.SetRepository(&SqlAnnotationRepo{})
ss.Bus.SetTransactionManager(&SQLTransactionManager{
engine: ss.engine,
})
// ensure admin user // ensure admin user
if ss.skipEnsureAdmin { if ss.skipEnsureAdmin {
return nil return nil
@ -85,10 +94,47 @@ func (ss *SqlStore) Init() error {
return ss.ensureAdminUser() return ss.ensureAdminUser()
} }
type SQLTransactionManager struct {
engine *xorm.Engine
}
func (stm *SQLTransactionManager) Begin(ctx context.Context) (context.Context, error) {
sess := stm.engine.NewSession()
err := sess.Begin()
if err != nil {
return ctx, err
}
withValue := context.WithValue(ctx, ContextSessionName, sess)
return withValue, nil
}
func (stm *SQLTransactionManager) End(ctx context.Context, err error) error {
value := ctx.Value(ContextSessionName)
sess, ok := value.(*xorm.Session)
if !ok {
return errors.New("context is missing transaction")
}
if err != nil {
sess.Rollback()
return err
}
defer sess.Close()
return sess.Commit()
}
func (ss *SqlStore) ensureAdminUser() error { func (ss *SqlStore) ensureAdminUser() error {
systemUserCountQuery := m.GetSystemUserCountStatsQuery{} systemUserCountQuery := m.GetSystemUserCountStatsQuery{}
if err := bus.Dispatch(&systemUserCountQuery); err != nil { err := bus.InTransaction(context.Background(), func(ctx context.Context) error {
return bus.DispatchCtx(ctx, &systemUserCountQuery)
})
if err != nil {
return fmt.Errorf("Could not determine if admin user exists: %v", err) return fmt.Errorf("Could not determine if admin user exists: %v", err)
} }
@ -240,6 +286,7 @@ func (ss *SqlStore) readConfig() {
func InitTestDB(t *testing.T) *SqlStore { func InitTestDB(t *testing.T) *SqlStore {
sqlstore := &SqlStore{} sqlstore := &SqlStore{}
sqlstore.skipEnsureAdmin = true sqlstore.skipEnsureAdmin = true
sqlstore.Bus = bus.New()
dbType := migrator.SQLITE dbType := migrator.SQLITE

View File

@ -1,6 +1,7 @@
package sqlstore package sqlstore
import ( import (
"context"
"time" "time"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -13,6 +14,7 @@ func init() {
bus.AddHandler("sql", GetDataSourceAccessStats) bus.AddHandler("sql", GetDataSourceAccessStats)
bus.AddHandler("sql", GetAdminStats) bus.AddHandler("sql", GetAdminStats)
bus.AddHandler("sql", GetSystemUserCountStats) bus.AddHandler("sql", GetSystemUserCountStats)
bus.AddCtxHandler("sql", GetSystemUserCountStatsCtx)
} }
var activeUserTimeLimit = time.Hour * 24 * 30 var activeUserTimeLimit = time.Hour * 24 * 30
@ -133,6 +135,22 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error {
return err return err
} }
func GetSystemUserCountStatsCtx(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error {
return withDbSession(ctx, func(sess *DBSession) error {
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
var stats m.SystemUserCountStats
_, err := sess.SQL(rawSql).Get(&stats)
if err != nil {
return err
}
query.Result = &stats
return err
})
}
func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error { func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error {
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user") var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
var stats m.SystemUserCountStats var stats m.SystemUserCountStats