diff --git a/app/plugin_api_tests/test_db_driver/main.go b/app/plugin_api_tests/test_db_driver/main.go index 6c39981b27..4669e2b661 100644 --- a/app/plugin_api_tests/test_db_driver/main.go +++ b/app/plugin_api_tests/test_db_driver/main.go @@ -29,7 +29,9 @@ func (p *MyPlugin) OnConfigurationChange() error { } func (p *MyPlugin) MessageWillBePosted(_ *plugin.Context, _ *model.Post) (*model.Post, string) { - store := sqlstore.New(p.API.GetUnsanitizedConfig().SqlSettings, nil) + settings := p.API.GetUnsanitizedConfig().SqlSettings + settings.Trace = model.NewBool(false) + store := sqlstore.New(settings, nil) store.GetMaster().Db.Close() for _, isMaster := range []bool{true, false} { diff --git a/store/sqlstore/sqlx_wrapper.go b/store/sqlstore/sqlx_wrapper.go new file mode 100644 index 0000000000..b91110b19e --- /dev/null +++ b/store/sqlstore/sqlx_wrapper.go @@ -0,0 +1,287 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "context" + "database/sql" + "regexp" + "strconv" + "strings" + "time" + "unicode" + + "github.com/jmoiron/sqlx" + + "github.com/mattermost/mattermost-server/v6/model" + "github.com/mattermost/mattermost-server/v6/shared/mlog" +) + +// namedParamRegex is used to capture all named parameters and convert them +// to lowercase. This is necessary to be able to use a single query for both +// Postgres and MySQL. +// This will also lowercase any constant strings containing a :, but sqlx +// will fail the query, so it won't be checked in inadvertently. +var namedParamRegex = regexp.MustCompile(`:\w+`) + +type sqlxDBWrapper struct { + *sqlx.DB + queryTimeout time.Duration + trace bool +} + +func newSqlxDBWrapper(db *sqlx.DB, timeout time.Duration, trace bool) *sqlxDBWrapper { + return &sqlxDBWrapper{ + DB: db, + queryTimeout: timeout, + trace: trace, + } +} + +func (w *sqlxDBWrapper) Beginx() (*sqlxTxWrapper, error) { + tx, err := w.DB.Beginx() + if err != nil { + return nil, err + } + + return newSqlxTxWrapper(tx, w.queryTimeout, w.trace), nil +} + +func (w *sqlxDBWrapper) Get(dest interface{}, query string, args ...interface{}) error { + query = w.DB.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.DB.GetContext(ctx, dest, query, args...) +} + +func (w *sqlxDBWrapper) NamedExec(query string, arg interface{}) (sql.Result, error) { + if w.DB.DriverName() == model.DatabaseDriverPostgres { + query = namedParamRegex.ReplaceAllStringFunc(query, strings.ToLower) + } + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), arg) + }(time.Now()) + } + + return w.DB.NamedExecContext(ctx, query, arg) +} + +func (w *sqlxDBWrapper) NamedQuery(query string, arg interface{}) (*sqlx.Rows, error) { + if w.DB.DriverName() == model.DatabaseDriverPostgres { + query = namedParamRegex.ReplaceAllStringFunc(query, strings.ToLower) + } + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), arg) + }(time.Now()) + } + + return w.DB.NamedQueryContext(ctx, query, arg) +} + +func (w *sqlxDBWrapper) QueryRowX(query string, args ...interface{}) *sqlx.Row { + query = w.DB.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.DB.QueryRowxContext(ctx, query, args...) +} + +func (w *sqlxDBWrapper) QueryX(query string, args ...interface{}) (*sqlx.Rows, error) { + query = w.DB.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.DB.QueryxContext(ctx, query, args) +} + +func (w *sqlxDBWrapper) Select(dest interface{}, query string, args ...interface{}) error { + query = w.DB.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.DB.SelectContext(ctx, dest, query, args...) +} + +type sqlxTxWrapper struct { + *sqlx.Tx + queryTimeout time.Duration + trace bool +} + +func newSqlxTxWrapper(tx *sqlx.Tx, timeout time.Duration, trace bool) *sqlxTxWrapper { + return &sqlxTxWrapper{ + Tx: tx, + queryTimeout: timeout, + trace: trace, + } +} + +func (w *sqlxTxWrapper) Get(dest interface{}, query string, args ...interface{}) error { + query = w.Tx.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.Tx.GetContext(ctx, dest, query, args...) +} + +func (w *sqlxTxWrapper) NamedExec(query string, arg interface{}) (sql.Result, error) { + if w.Tx.DriverName() == model.DatabaseDriverPostgres { + query = namedParamRegex.ReplaceAllStringFunc(query, strings.ToLower) + } + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), arg) + }(time.Now()) + } + + return w.Tx.NamedExecContext(ctx, query, arg) +} + +func (w *sqlxTxWrapper) NamedQuery(query string, arg interface{}) (*sqlx.Rows, error) { + if w.Tx.DriverName() == model.DatabaseDriverPostgres { + query = namedParamRegex.ReplaceAllStringFunc(query, strings.ToLower) + } + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), arg) + }(time.Now()) + } + + // There is no tx.NamedQueryContext support in the sqlx API. (https://github.com/jmoiron/sqlx/issues/447) + // So we need to implement this ourselves. + type result struct { + rows *sqlx.Rows + err error + } + + // Need to add a buffer of 1 to prevent goroutine leak. + resChan := make(chan *result, 1) + go func() { + rows, err := w.Tx.NamedQuery(query, arg) + resChan <- &result{ + rows: rows, + err: err, + } + }() + + // staticcheck fails to check that res gets re-assigned later. + res := &result{} //nolint:staticcheck + select { + case res = <-resChan: + case <-ctx.Done(): + res = &result{ + rows: nil, + err: ctx.Err(), + } + } + + return res.rows, res.err +} + +func (w *sqlxTxWrapper) QueryRowX(query string, args ...interface{}) *sqlx.Row { + query = w.Tx.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.Tx.QueryRowxContext(ctx, query, args...) +} + +func (w *sqlxTxWrapper) QueryX(query string, args ...interface{}) (*sqlx.Rows, error) { + query = w.Tx.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.Tx.QueryxContext(ctx, query, args) +} + +func (w *sqlxTxWrapper) Select(dest interface{}, query string, args ...interface{}) error { + query = w.Tx.Rebind(query) + ctx, cancel := context.WithTimeout(context.Background(), w.queryTimeout) + defer cancel() + + if w.trace { + defer func(then time.Time) { + printArgs(query, time.Since(then), args) + }(time.Now()) + } + + return w.Tx.SelectContext(ctx, dest, query, args...) +} + +func removeSpace(r rune) rune { + // Strip everything except ' ' + // This also strips out more than one space, + // but we ignore it for now until someone complains. + if unicode.IsSpace(r) && r != ' ' { + return -1 + } + return r +} + +func printArgs(query string, dur time.Duration, args ...interface{}) { + query = strings.Map(removeSpace, query) + fields := make([]mlog.Field, 0, len(args)+1) + fields = append(fields, mlog.Duration("duration", dur)) + for i, arg := range args { + fields = append(fields, mlog.Any("arg"+strconv.Itoa(i), arg)) + } + mlog.Debug(query, fields...) +} diff --git a/store/sqlstore/sqlx_wrapper_test.go b/store/sqlstore/sqlx_wrapper_test.go new file mode 100644 index 0000000000..c9c5ea18f6 --- /dev/null +++ b/store/sqlstore/sqlx_wrapper_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "context" + "strings" + "testing" + + "github.com/mattermost/mattermost-server/v6/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSqlX(t *testing.T) { + t.Run("NamedQuery", func(t *testing.T) { + testDrivers := []string{ + model.DatabaseDriverPostgres, + model.DatabaseDriverMysql, + } + + for _, driver := range testDrivers { + settings := makeSqlSettings(driver) + *settings.QueryTimeout = 1 + store := &SqlStore{ + rrCounter: 0, + srCounter: 0, + settings: settings, + } + + store.initConnection() + + defer store.Close() + + tx, err := store.GetMasterX().Beginx() + require.NoError(t, err) + + var query string + if store.DriverName() == model.DatabaseDriverMysql { + query = `SELECT SLEEP(:Timeout);` + } else if store.DriverName() == model.DatabaseDriverPostgres { + query = `SELECT pg_sleep(:timeout);` + } + arg := struct{ Timeout int }{Timeout: 2} + _, err = tx.NamedQuery(query, arg) + require.Equal(t, context.DeadlineExceeded, err) + require.NoError(t, tx.Commit()) + } + }) + + t.Run("NamedParse", func(t *testing.T) { + queries := []struct { + in string + out string + }{ + { + in: `SELECT pg_sleep(:Timeout)`, + out: `SELECT pg_sleep(:timeout)`, + }, + { + in: `SELECT u.Username FROM Bots + LIMIT + :Limit + OFFSET + :Offset`, + out: `SELECT u.Username FROM Bots + LIMIT + :limit + OFFSET + :offset`, + }, + { + in: `UPDATE OAuthAccessData SET Token =:Token`, + out: `UPDATE OAuthAccessData SET Token =:token`, + }, + } + for _, q := range queries { + out := namedParamRegex.ReplaceAllStringFunc(q.in, strings.ToLower) + assert.Equal(t, q.out, out) + } + }) +} diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index badba9251a..19a59345aa 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -142,13 +142,13 @@ type SqlStore struct { srCounter int64 master *gorp.DbMap - masterX *sqlx.DB + masterX *sqlxDBWrapper Replicas []*gorp.DbMap - ReplicaXs []*sqlx.DB + ReplicaXs []*sqlxDBWrapper searchReplicas []*gorp.DbMap - searchReplicaXs []*sqlx.DB + searchReplicaXs []*sqlxDBWrapper replicaLagHandles []*dbsql.DB stores SqlStoreStores @@ -389,18 +389,22 @@ func (ss *SqlStore) initConnection() { handle := setupConnection("master", dataSource, ss.settings) ss.master = getDBMap(ss.settings, handle) - ss.masterX = sqlx.NewDb(handle, ss.DriverName()) + ss.masterX = newSqlxDBWrapper(sqlx.NewDb(handle, ss.DriverName()), + time.Duration(*ss.settings.QueryTimeout)*time.Second, + *ss.settings.Trace) if ss.DriverName() == model.DatabaseDriverMysql { ss.masterX.MapperFunc(noOpMapper) } if len(ss.settings.DataSourceReplicas) > 0 { ss.Replicas = make([]*gorp.DbMap, len(ss.settings.DataSourceReplicas)) - ss.ReplicaXs = make([]*sqlx.DB, len(ss.settings.DataSourceReplicas)) + ss.ReplicaXs = make([]*sqlxDBWrapper, len(ss.settings.DataSourceReplicas)) for i, replica := range ss.settings.DataSourceReplicas { handle := setupConnection(fmt.Sprintf("replica-%v", i), replica, ss.settings) ss.Replicas[i] = getDBMap(ss.settings, handle) - ss.ReplicaXs[i] = sqlx.NewDb(handle, ss.DriverName()) + ss.ReplicaXs[i] = newSqlxDBWrapper(sqlx.NewDb(handle, ss.DriverName()), + time.Duration(*ss.settings.QueryTimeout)*time.Second, + *ss.settings.Trace) if ss.DriverName() == model.DatabaseDriverMysql { ss.ReplicaXs[i].MapperFunc(noOpMapper) } @@ -409,11 +413,13 @@ func (ss *SqlStore) initConnection() { if len(ss.settings.DataSourceSearchReplicas) > 0 { ss.searchReplicas = make([]*gorp.DbMap, len(ss.settings.DataSourceSearchReplicas)) - ss.searchReplicaXs = make([]*sqlx.DB, len(ss.settings.DataSourceSearchReplicas)) + ss.searchReplicaXs = make([]*sqlxDBWrapper, len(ss.settings.DataSourceSearchReplicas)) for i, replica := range ss.settings.DataSourceSearchReplicas { handle := setupConnection(fmt.Sprintf("search-replica-%v", i), replica, ss.settings) ss.searchReplicas[i] = getDBMap(ss.settings, handle) - ss.searchReplicaXs[i] = sqlx.NewDb(handle, ss.DriverName()) + ss.searchReplicaXs[i] = newSqlxDBWrapper(sqlx.NewDb(handle, ss.DriverName()), + time.Duration(*ss.settings.QueryTimeout)*time.Second, + *ss.settings.Trace) if ss.DriverName() == model.DatabaseDriverMysql { ss.searchReplicaXs[i].MapperFunc(noOpMapper) } @@ -470,7 +476,7 @@ func (ss *SqlStore) GetMaster() *gorp.DbMap { return ss.master } -func (ss *SqlStore) GetMasterX() *sqlx.DB { +func (ss *SqlStore) GetMasterX() *sqlxDBWrapper { return ss.masterX } @@ -502,7 +508,7 @@ func (ss *SqlStore) GetReplica() *gorp.DbMap { return ss.Replicas[rrNum] } -func (ss *SqlStore) GetReplicaX() *sqlx.DB { +func (ss *SqlStore) GetReplicaX() *sqlxDBWrapper { ss.licenseMutex.RLock() license := ss.license ss.licenseMutex.RUnlock()