mirror of
https://github.com/grafana/grafana.git
synced 2025-01-23 15:03:41 -06:00
Merge pull request #12203 from bergquist/bus_multi_dispatch
bus: support multiple dispatch in one transaction
This commit is contained in:
commit
d6f4313c2f
@ -12,21 +12,42 @@ type Msg interface{}
|
||||
|
||||
var ErrHandlerNotFound = errors.New("handler not found")
|
||||
|
||||
type TransactionManager interface {
|
||||
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
|
||||
}
|
||||
|
||||
type Bus interface {
|
||||
Dispatch(msg Msg) error
|
||||
DispatchCtx(ctx context.Context, msg Msg) error
|
||||
Publish(msg Msg) error
|
||||
|
||||
// InTransaction starts a transaction and store it in the context.
|
||||
// The caller can then pass a function with multiple DispatchCtx calls that
|
||||
// all will be executed in the same transaction. InTransaction will rollback if the
|
||||
// callback returns an error.
|
||||
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
|
||||
|
||||
AddHandler(handler HandlerFunc)
|
||||
AddCtxHandler(handler HandlerFunc)
|
||||
AddHandlerCtx(handler HandlerFunc)
|
||||
AddEventListener(handler HandlerFunc)
|
||||
AddWildcardListener(handler HandlerFunc)
|
||||
|
||||
// SetTransactionManager allows the user to replace the internal
|
||||
// noop TransactionManager that is responsible for manageing
|
||||
// transactions in `InTransaction`
|
||||
SetTransactionManager(tm TransactionManager)
|
||||
}
|
||||
|
||||
func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return b.txMng.InTransaction(ctx, fn)
|
||||
}
|
||||
|
||||
type InProcBus struct {
|
||||
handlers map[string]HandlerFunc
|
||||
handlersWithCtx map[string]HandlerFunc
|
||||
listeners map[string][]HandlerFunc
|
||||
wildcardListeners []HandlerFunc
|
||||
txMng TransactionManager
|
||||
}
|
||||
|
||||
// temp stuff, not sure how to handle bus instance, and init yet
|
||||
@ -35,8 +56,11 @@ var globalBus = New()
|
||||
func New() Bus {
|
||||
bus := &InProcBus{}
|
||||
bus.handlers = make(map[string]HandlerFunc)
|
||||
bus.handlersWithCtx = make(map[string]HandlerFunc)
|
||||
bus.listeners = make(map[string][]HandlerFunc)
|
||||
bus.wildcardListeners = make([]HandlerFunc, 0)
|
||||
bus.txMng = &noopTransactionManager{}
|
||||
|
||||
return bus
|
||||
}
|
||||
|
||||
@ -45,17 +69,21 @@ func GetBus() Bus {
|
||||
return globalBus
|
||||
}
|
||||
|
||||
func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
|
||||
b.txMng = tm
|
||||
}
|
||||
|
||||
func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
|
||||
var msgName = reflect.TypeOf(msg).Elem().Name()
|
||||
|
||||
var handler = b.handlers[msgName]
|
||||
var handler = b.handlersWithCtx[msgName]
|
||||
if handler == nil {
|
||||
return ErrHandlerNotFound
|
||||
}
|
||||
|
||||
var params = make([]reflect.Value, 2)
|
||||
params[0] = reflect.ValueOf(ctx)
|
||||
params[1] = reflect.ValueOf(msg)
|
||||
var params = []reflect.Value{}
|
||||
params = append(params, reflect.ValueOf(ctx))
|
||||
params = append(params, reflect.ValueOf(msg))
|
||||
|
||||
ret := reflect.ValueOf(handler).Call(params)
|
||||
err := ret[0].Interface()
|
||||
@ -68,13 +96,23 @@ func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
|
||||
func (b *InProcBus) Dispatch(msg Msg) error {
|
||||
var msgName = reflect.TypeOf(msg).Elem().Name()
|
||||
|
||||
var handler = b.handlers[msgName]
|
||||
var handler = b.handlersWithCtx[msgName]
|
||||
withCtx := true
|
||||
|
||||
if handler == nil {
|
||||
withCtx = false
|
||||
handler = b.handlers[msgName]
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
return ErrHandlerNotFound
|
||||
}
|
||||
|
||||
var params = make([]reflect.Value, 1)
|
||||
params[0] = reflect.ValueOf(msg)
|
||||
var params = []reflect.Value{}
|
||||
if withCtx {
|
||||
params = append(params, reflect.ValueOf(context.Background()))
|
||||
}
|
||||
params = append(params, reflect.ValueOf(msg))
|
||||
|
||||
ret := reflect.ValueOf(handler).Call(params)
|
||||
err := ret[0].Interface()
|
||||
@ -120,10 +158,10 @@ func (b *InProcBus) AddHandler(handler HandlerFunc) {
|
||||
b.handlers[queryTypeName] = handler
|
||||
}
|
||||
|
||||
func (b *InProcBus) AddCtxHandler(handler HandlerFunc) {
|
||||
func (b *InProcBus) AddHandlerCtx(handler HandlerFunc) {
|
||||
handlerType := reflect.TypeOf(handler)
|
||||
queryTypeName := handlerType.In(1).Elem().Name()
|
||||
b.handlers[queryTypeName] = handler
|
||||
b.handlersWithCtx[queryTypeName] = handler
|
||||
}
|
||||
|
||||
func (b *InProcBus) AddEventListener(handler HandlerFunc) {
|
||||
@ -142,8 +180,8 @@ func AddHandler(implName string, handler HandlerFunc) {
|
||||
}
|
||||
|
||||
// Package level functions
|
||||
func AddCtxHandler(implName string, handler HandlerFunc) {
|
||||
globalBus.AddCtxHandler(handler)
|
||||
func AddHandlerCtx(implName string, handler HandlerFunc) {
|
||||
globalBus.AddHandlerCtx(handler)
|
||||
}
|
||||
|
||||
// Package level functions
|
||||
@ -167,6 +205,20 @@ func Publish(msg Msg) error {
|
||||
return globalBus.Publish(msg)
|
||||
}
|
||||
|
||||
// InTransaction starts a transaction and store it in the context.
|
||||
// The caller can then pass a function with multiple DispatchCtx calls that
|
||||
// all will be executed in the same transaction. InTransaction will rollback if the
|
||||
// callback returns an error.
|
||||
func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return globalBus.InTransaction(ctx, fn)
|
||||
}
|
||||
|
||||
func ClearBusHandlers() {
|
||||
globalBus = New()
|
||||
}
|
||||
|
||||
type noopTransactionManager struct{}
|
||||
|
||||
func (*noopTransactionManager) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
@ -1,24 +1,67 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestQuery struct {
|
||||
type testQuery struct {
|
||||
Id int64
|
||||
Resp string
|
||||
}
|
||||
|
||||
func TestDispatchCtxCanUseNormalHandlers(t *testing.T) {
|
||||
bus := New()
|
||||
|
||||
handlerWithCtxCallCount := 0
|
||||
handlerCallCount := 0
|
||||
|
||||
handlerWithCtx := func(ctx context.Context, query *testQuery) error {
|
||||
handlerWithCtxCallCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
handler := func(query *testQuery) error {
|
||||
handlerCallCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := bus.DispatchCtx(context.Background(), &testQuery{})
|
||||
if err != ErrHandlerNotFound {
|
||||
t.Errorf("expected bus to return HandlerNotFound is no handler is registered")
|
||||
}
|
||||
|
||||
bus.AddHandler(handler)
|
||||
|
||||
t.Run("when a normal handler is registered", func(t *testing.T) {
|
||||
bus.Dispatch(&testQuery{})
|
||||
|
||||
if handlerCallCount != 1 {
|
||||
t.Errorf("Expected normal handler to be called 1 time. was called %d", handlerCallCount)
|
||||
}
|
||||
|
||||
t.Run("when a ctx handler is registered", func(t *testing.T) {
|
||||
bus.AddHandlerCtx(handlerWithCtx)
|
||||
bus.Dispatch(&testQuery{})
|
||||
|
||||
if handlerWithCtxCallCount != 1 {
|
||||
t.Errorf("Expected ctx handler to be called 1 time. was called %d", handlerWithCtxCallCount)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestQueryHandlerReturnsError(t *testing.T) {
|
||||
bus := New()
|
||||
|
||||
bus.AddHandler(func(query *TestQuery) error {
|
||||
bus.AddHandler(func(query *testQuery) error {
|
||||
return errors.New("handler error")
|
||||
})
|
||||
|
||||
err := bus.Dispatch(&TestQuery{})
|
||||
err := bus.Dispatch(&testQuery{})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Send query failed " + err.Error())
|
||||
@ -30,12 +73,12 @@ func TestQueryHandlerReturnsError(t *testing.T) {
|
||||
func TestQueryHandlerReturn(t *testing.T) {
|
||||
bus := New()
|
||||
|
||||
bus.AddHandler(func(q *TestQuery) error {
|
||||
bus.AddHandler(func(q *testQuery) error {
|
||||
q.Resp = "hello from handler"
|
||||
return nil
|
||||
})
|
||||
|
||||
query := &TestQuery{}
|
||||
query := &testQuery{}
|
||||
err := bus.Dispatch(query)
|
||||
|
||||
if err != nil {
|
||||
@ -49,17 +92,17 @@ func TestEventListeners(t *testing.T) {
|
||||
bus := New()
|
||||
count := 0
|
||||
|
||||
bus.AddEventListener(func(query *TestQuery) error {
|
||||
bus.AddEventListener(func(query *testQuery) error {
|
||||
count += 1
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddEventListener(func(query *TestQuery) error {
|
||||
bus.AddEventListener(func(query *testQuery) error {
|
||||
count += 10
|
||||
return nil
|
||||
})
|
||||
|
||||
err := bus.Publish(&TestQuery{})
|
||||
err := bus.Publish(&testQuery{})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Publish event failed " + err.Error())
|
||||
|
@ -3,6 +3,7 @@ package notifiers
|
||||
import (
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/alerting"
|
||||
)
|
||||
|
||||
|
@ -45,8 +45,8 @@ func (ns *NotificationService) Init() error {
|
||||
ns.Bus.AddHandler(ns.validateResetPasswordCode)
|
||||
ns.Bus.AddHandler(ns.sendEmailCommandHandler)
|
||||
|
||||
ns.Bus.AddCtxHandler(ns.sendEmailCommandHandlerSync)
|
||||
ns.Bus.AddCtxHandler(ns.SendWebhookSync)
|
||||
ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync)
|
||||
ns.Bus.AddHandlerCtx(ns.SendWebhookSync)
|
||||
|
||||
ns.Bus.AddEventListener(ns.signUpStartedHandler)
|
||||
ns.Bus.AddEventListener(ns.signUpCompletedHandler)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -11,7 +12,7 @@ func init() {
|
||||
bus.AddHandler("sql", GetApiKeys)
|
||||
bus.AddHandler("sql", GetApiKeyById)
|
||||
bus.AddHandler("sql", GetApiKeyByName)
|
||||
bus.AddHandler("sql", DeleteApiKey)
|
||||
bus.AddHandlerCtx("sql", DeleteApiKeyCtx)
|
||||
bus.AddHandler("sql", AddApiKey)
|
||||
}
|
||||
|
||||
@ -22,8 +23,8 @@ func GetApiKeys(query *m.GetApiKeysQuery) error {
|
||||
return sess.Find(&query.Result)
|
||||
}
|
||||
|
||||
func DeleteApiKey(cmd *m.DeleteApiKeyCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
func DeleteApiKeyCtx(ctx context.Context, cmd *m.DeleteApiKeyCommand) error {
|
||||
return withDbSession(ctx, func(sess *DBSession) error {
|
||||
var rawSql = "DELETE FROM api_key WHERE id=? and org_id=?"
|
||||
_, err := sess.Exec(rawSql, cmd.Id, cmd.OrgId)
|
||||
return err
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -389,7 +390,7 @@ func createUser(name string, role string, isAdmin bool) m.User {
|
||||
setting.AutoAssignOrgRole = role
|
||||
|
||||
currentUserCmd := m.CreateUserCommand{Login: name, Email: name + "@test.com", Name: "a " + name, IsAdmin: isAdmin}
|
||||
err := CreateUser(¤tUserCmd)
|
||||
err := CreateUser(context.Background(), ¤tUserCmd)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
q1 := m.GetUserOrgListQuery{UserId: currentUserCmd.Result.Id}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -22,9 +23,9 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
ac1cmd := m.CreateUserCommand{Login: "ac1", Email: "ac1@test.com", Name: "ac1 name"}
|
||||
ac2cmd := m.CreateUserCommand{Login: "ac2", Email: "ac2@test.com", Name: "ac2 name"}
|
||||
|
||||
err := CreateUser(&ac1cmd)
|
||||
err := CreateUser(context.Background(), &ac1cmd)
|
||||
So(err, ShouldBeNil)
|
||||
err = CreateUser(&ac2cmd)
|
||||
err = CreateUser(context.Background(), &ac2cmd)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
q1 := m.GetUserOrgListQuery{UserId: ac1cmd.Result.Id}
|
||||
@ -43,8 +44,8 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
ac1cmd := m.CreateUserCommand{Login: "ac1", Email: "ac1@test.com", Name: "ac1 name"}
|
||||
ac2cmd := m.CreateUserCommand{Login: "ac2", Email: "ac2@test.com", Name: "ac2 name", IsAdmin: true}
|
||||
|
||||
err := CreateUser(&ac1cmd)
|
||||
err = CreateUser(&ac2cmd)
|
||||
err := CreateUser(context.Background(), &ac1cmd)
|
||||
err = CreateUser(context.Background(), &ac2cmd)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
ac1 := ac1cmd.Result
|
||||
@ -182,7 +183,7 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
|
||||
Convey("Given an org user with dashboard permissions", func() {
|
||||
ac3cmd := m.CreateUserCommand{Login: "ac3", Email: "ac3@test.com", Name: "ac3 name", IsAdmin: false}
|
||||
err := CreateUser(&ac3cmd)
|
||||
err := CreateUser(context.Background(), &ac3cmd)
|
||||
So(err, ShouldBeNil)
|
||||
ac3 := ac3cmd.Result
|
||||
|
||||
|
71
pkg/services/sqlstore/session.go
Normal file
71
pkg/services/sqlstore/session.go
Normal file
@ -0,0 +1,71 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/go-xorm/xorm"
|
||||
)
|
||||
|
||||
type DBSession struct {
|
||||
*xorm.Session
|
||||
events []interface{}
|
||||
}
|
||||
|
||||
type dbTransactionFunc func(sess *DBSession) error
|
||||
|
||||
func (sess *DBSession) publishAfterCommit(msg interface{}) {
|
||||
sess.events = append(sess.events, msg)
|
||||
}
|
||||
|
||||
func newSession() *DBSession {
|
||||
return &DBSession{Session: x.NewSession()}
|
||||
}
|
||||
|
||||
func startSession(ctx context.Context, engine *xorm.Engine, beginTran bool) (*DBSession, error) {
|
||||
value := ctx.Value(ContextSessionName)
|
||||
var sess *DBSession
|
||||
sess, ok := value.(*DBSession)
|
||||
|
||||
if !ok {
|
||||
newSess := &DBSession{Session: engine.NewSession()}
|
||||
if beginTran {
|
||||
err := newSess.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return newSess, nil
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
|
||||
sess, err := startSession(ctx, x, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return callback(sess)
|
||||
}
|
||||
|
||||
func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
|
||||
table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
|
||||
|
||||
dialect.PreInsertId(table, sess.Session)
|
||||
|
||||
id, err := sess.Session.InsertOne(bean)
|
||||
|
||||
dialect.PostInsertId(table, sess.Session)
|
||||
|
||||
return id, err
|
||||
}
|
||||
|
||||
func getTypeName(bean interface{}) (res string) {
|
||||
t := reflect.TypeOf(bean)
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t.Name()
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/go-xorm/xorm"
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/log"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type DBSession struct {
|
||||
*xorm.Session
|
||||
events []interface{}
|
||||
}
|
||||
|
||||
type dbTransactionFunc func(sess *DBSession) error
|
||||
|
||||
func (sess *DBSession) publishAfterCommit(msg interface{}) {
|
||||
sess.events = append(sess.events, msg)
|
||||
}
|
||||
|
||||
func newSession() *DBSession {
|
||||
return &DBSession{Session: x.NewSession()}
|
||||
}
|
||||
|
||||
func inTransaction(callback dbTransactionFunc) error {
|
||||
return inTransactionWithRetry(callback, 0)
|
||||
}
|
||||
|
||||
func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
|
||||
var err error
|
||||
|
||||
sess := newSession()
|
||||
defer sess.Close()
|
||||
|
||||
if err = sess.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = callback(sess)
|
||||
|
||||
// special handling of database locked errors for sqlite, then we can retry 3 times
|
||||
if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
|
||||
if sqlError.Code == sqlite3.ErrLocked {
|
||||
sess.Rollback()
|
||||
time.Sleep(time.Millisecond * time.Duration(10))
|
||||
sqlog.Info("Database table locked, sleeping then retrying", "retry", retry)
|
||||
return inTransactionWithRetry(callback, retry+1)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
sess.Rollback()
|
||||
return err
|
||||
} else if err = sess.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(sess.events) > 0 {
|
||||
for _, e := range sess.events {
|
||||
if err = bus.Publish(e); err != nil {
|
||||
log.Error(3, "Failed to publish event after commit", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
|
||||
table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
|
||||
|
||||
dialect.PreInsertId(table, sess.Session)
|
||||
|
||||
id, err := sess.Session.InsertOne(bean)
|
||||
|
||||
dialect.PostInsertId(table, sess.Session)
|
||||
|
||||
return id, err
|
||||
}
|
||||
|
||||
func getTypeName(bean interface{}) (res string) {
|
||||
t := reflect.TypeOf(bean)
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t.Name()
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -22,10 +23,10 @@ import (
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/go-xorm/xorm"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
_ "github.com/grafana/grafana/pkg/tsdb/mssql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -35,6 +36,8 @@ var (
|
||||
sqlog log.Logger = log.New("sqlstore")
|
||||
)
|
||||
|
||||
const ContextSessionName = "db-session"
|
||||
|
||||
func init() {
|
||||
registry.Register(®istry.Descriptor{
|
||||
Name: "SqlStore",
|
||||
@ -45,6 +48,7 @@ func init() {
|
||||
|
||||
type SqlStore struct {
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
Bus bus.Bus `inject:""`
|
||||
|
||||
dbCfg DatabaseConfig
|
||||
engine *xorm.Engine
|
||||
@ -77,6 +81,8 @@ func (ss *SqlStore) Init() error {
|
||||
// Init repo instances
|
||||
annotations.SetRepository(&SqlAnnotationRepo{})
|
||||
|
||||
ss.Bus.SetTransactionManager(ss)
|
||||
|
||||
// ensure admin user
|
||||
if ss.skipEnsureAdmin {
|
||||
return nil
|
||||
@ -88,27 +94,33 @@ func (ss *SqlStore) Init() error {
|
||||
func (ss *SqlStore) ensureAdminUser() error {
|
||||
systemUserCountQuery := m.GetSystemUserCountStatsQuery{}
|
||||
|
||||
if err := bus.Dispatch(&systemUserCountQuery); err != nil {
|
||||
return fmt.Errorf("Could not determine if admin user exists: %v", err)
|
||||
}
|
||||
err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
|
||||
|
||||
err := bus.DispatchCtx(ctx, &systemUserCountQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not determine if admin user exists: %v", err)
|
||||
}
|
||||
|
||||
if systemUserCountQuery.Result.Count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := m.CreateUserCommand{}
|
||||
cmd.Login = setting.AdminUser
|
||||
cmd.Email = setting.AdminUser + "@localhost"
|
||||
cmd.Password = setting.AdminPassword
|
||||
cmd.IsAdmin = true
|
||||
|
||||
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
|
||||
return fmt.Errorf("Failed to create admin user: %v", err)
|
||||
}
|
||||
|
||||
ss.log.Info("Created default admin", "user", setting.AdminUser)
|
||||
|
||||
if systemUserCountQuery.Result.Count > 0 {
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
cmd := m.CreateUserCommand{}
|
||||
cmd.Login = setting.AdminUser
|
||||
cmd.Email = setting.AdminUser + "@localhost"
|
||||
cmd.Password = setting.AdminPassword
|
||||
cmd.IsAdmin = true
|
||||
|
||||
if err := bus.Dispatch(&cmd); err != nil {
|
||||
return fmt.Errorf("Failed to create admin user: %v", err)
|
||||
}
|
||||
|
||||
ss.log.Info("Created default admin user: %v", setting.AdminUser)
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (ss *SqlStore) buildConnectionString() (string, error) {
|
||||
@ -238,8 +250,10 @@ func (ss *SqlStore) readConfig() {
|
||||
}
|
||||
|
||||
func InitTestDB(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
sqlstore := &SqlStore{}
|
||||
sqlstore.skipEnsureAdmin = true
|
||||
sqlstore.Bus = bus.New()
|
||||
|
||||
dbType := migrator.SQLITE
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -12,7 +13,7 @@ func init() {
|
||||
bus.AddHandler("sql", GetDataSourceStats)
|
||||
bus.AddHandler("sql", GetDataSourceAccessStats)
|
||||
bus.AddHandler("sql", GetAdminStats)
|
||||
bus.AddHandler("sql", GetSystemUserCountStats)
|
||||
bus.AddHandlerCtx("sql", GetSystemUserCountStats)
|
||||
}
|
||||
|
||||
var activeUserTimeLimit = time.Hour * 24 * 30
|
||||
@ -133,15 +134,18 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error {
|
||||
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
|
||||
var stats m.SystemUserCountStats
|
||||
_, err := x.SQL(rawSql).Get(&stats)
|
||||
if err != nil {
|
||||
func GetSystemUserCountStats(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error {
|
||||
return withDbSession(ctx, func(sess *DBSession) error {
|
||||
|
||||
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
|
||||
var stats m.SystemUserCountStats
|
||||
_, err := sess.SQL(rawSql).Get(&stats)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query.Result = &stats
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
query.Result = &stats
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
@ -20,7 +21,7 @@ func TestStatsDataAccess(t *testing.T) {
|
||||
|
||||
Convey("Get system user count stats should not results in error", func() {
|
||||
query := m.GetSystemUserCountStatsQuery{}
|
||||
err := GetSystemUserCountStats(&query)
|
||||
err := GetSystemUserCountStats(context.Background(), &query)
|
||||
So(err, ShouldBeNil)
|
||||
})
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@ -22,7 +23,7 @@ func TestTeamCommandsAndQueries(t *testing.T) {
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
err := CreateUser(userCmd)
|
||||
err := CreateUser(context.Background(), userCmd)
|
||||
So(err, ShouldBeNil)
|
||||
userIds = append(userIds, userCmd.Result.Id)
|
||||
}
|
||||
|
106
pkg/services/sqlstore/transactions.go
Normal file
106
pkg/services/sqlstore/transactions.go
Normal file
@ -0,0 +1,106 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/log"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func (ss *SqlStore) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return ss.inTransactionWithRetry(ctx, fn, 0)
|
||||
}
|
||||
|
||||
func (ss *SqlStore) inTransactionWithRetry(ctx context.Context, fn func(ctx context.Context) error, retry int) error {
|
||||
sess, err := startSession(ctx, ss.engine, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
withValue := context.WithValue(ctx, ContextSessionName, sess)
|
||||
|
||||
err = fn(withValue)
|
||||
|
||||
// special handling of database locked errors for sqlite, then we can retry 3 times
|
||||
if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
|
||||
if sqlError.Code == sqlite3.ErrLocked {
|
||||
sess.Rollback()
|
||||
time.Sleep(time.Millisecond * time.Duration(10))
|
||||
ss.log.Info("Database table locked, sleeping then retrying", "retry", retry)
|
||||
return ss.inTransactionWithRetry(ctx, fn, retry+1)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
sess.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err = sess.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(sess.events) > 0 {
|
||||
for _, e := range sess.events {
|
||||
if err = bus.Publish(e); err != nil {
|
||||
ss.log.Error("Failed to publish event after commit", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
|
||||
return inTransactionWithRetryCtx(context.Background(), callback, retry)
|
||||
}
|
||||
|
||||
func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error {
|
||||
sess, err := startSession(ctx, x, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
err = callback(sess)
|
||||
|
||||
// special handling of database locked errors for sqlite, then we can retry 3 times
|
||||
if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
|
||||
if sqlError.Code == sqlite3.ErrLocked {
|
||||
sess.Rollback()
|
||||
time.Sleep(time.Millisecond * time.Duration(10))
|
||||
sqlog.Info("Database table locked, sleeping then retrying", "retry", retry)
|
||||
return inTransactionWithRetry(callback, retry+1)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
sess.Rollback()
|
||||
return err
|
||||
} else if err = sess.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(sess.events) > 0 {
|
||||
for _, e := range sess.events {
|
||||
if err = bus.Publish(e); err != nil {
|
||||
log.Error(3, "Failed to publish event after commit", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func inTransaction(callback dbTransactionFunc) error {
|
||||
return inTransactionWithRetry(callback, 0)
|
||||
}
|
||||
|
||||
func inTransactionCtx(ctx context.Context, callback dbTransactionFunc) error {
|
||||
return inTransactionWithRetryCtx(ctx, callback, 0)
|
||||
}
|
60
pkg/services/sqlstore/transactions_test.go
Normal file
60
pkg/services/sqlstore/transactions_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
type testQuery struct {
|
||||
result bool
|
||||
}
|
||||
|
||||
var ProvokedError = errors.New("testing error.")
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
ss := InitTestDB(t)
|
||||
|
||||
Convey("InTransaction asdf asdf", t, func() {
|
||||
cmd := &models.AddApiKeyCommand{Key: "secret-key", Name: "key", OrgId: 1}
|
||||
|
||||
err := AddApiKey(cmd)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
deleteApiKeyCmd := &models.DeleteApiKeyCommand{Id: cmd.Result.Id, OrgId: 1}
|
||||
|
||||
Convey("can update key", func() {
|
||||
err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
|
||||
return DeleteApiKeyCtx(ctx, deleteApiKeyCmd)
|
||||
})
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
query := &models.GetApiKeyByIdQuery{ApiKeyId: cmd.Result.Id}
|
||||
err = GetApiKeyById(query)
|
||||
So(err, ShouldEqual, models.ErrInvalidApiKey)
|
||||
})
|
||||
|
||||
Convey("wont update if one handler fails", func() {
|
||||
err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
|
||||
err := DeleteApiKeyCtx(ctx, deleteApiKeyCmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ProvokedError
|
||||
})
|
||||
|
||||
So(err, ShouldEqual, ProvokedError)
|
||||
|
||||
query := &models.GetApiKeyByIdQuery{ApiKeyId: cmd.Result.Id}
|
||||
err = GetApiKeyById(query)
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Id, ShouldEqual, cmd.Result.Id)
|
||||
})
|
||||
})
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
bus.AddHandler("sql", CreateUser)
|
||||
//bus.AddHandler("sql", CreateUser)
|
||||
bus.AddHandler("sql", GetUserById)
|
||||
bus.AddHandler("sql", UpdateUser)
|
||||
bus.AddHandler("sql", ChangeUserPassword)
|
||||
@ -30,6 +31,7 @@ func init() {
|
||||
bus.AddHandler("sql", DeleteUser)
|
||||
bus.AddHandler("sql", UpdateUserPermissions)
|
||||
bus.AddHandler("sql", SetUserHelpFlag)
|
||||
bus.AddHandlerCtx("sql", CreateUser)
|
||||
}
|
||||
|
||||
func getOrgIdForNewUser(cmd *m.CreateUserCommand, sess *DBSession) (int64, error) {
|
||||
@ -79,8 +81,8 @@ func getOrgIdForNewUser(cmd *m.CreateUserCommand, sess *DBSession) (int64, error
|
||||
return org.Id, nil
|
||||
}
|
||||
|
||||
func CreateUser(cmd *m.CreateUserCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
func CreateUser(ctx context.Context, cmd *m.CreateUserCommand) error {
|
||||
return inTransactionCtx(ctx, func(sess *DBSession) error {
|
||||
orgId, err := getOrgIdForNewUser(cmd, sess)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@ -22,7 +23,7 @@ func TestUserAuth(t *testing.T) {
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
err = CreateUser(cmd)
|
||||
err = CreateUser(context.Background(), cmd)
|
||||
So(err, ShouldBeNil)
|
||||
users = append(users, cmd.Result)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@ -24,7 +25,7 @@ func TestUserDataAccess(t *testing.T) {
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
err = CreateUser(cmd)
|
||||
err = CreateUser(context.Background(), cmd)
|
||||
So(err, ShouldBeNil)
|
||||
users = append(users, cmd.Result)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user