mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
Moved some common SQL function to public utils as they are used in plugins (#26412)
* Moved some common SQL function tu public utls as they are used in plugins * goimported file * Added tests * Created sub-package * MOved SetupConnection to public sql utils
This commit is contained in:
parent
f34445a6f4
commit
4fda7e6f34
@ -12,6 +12,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/db"
|
"github.com/mattermost/mattermost/server/v8/channels/db"
|
||||||
@ -119,16 +121,16 @@ func (ss *SqlStore) initMorph(dryRun bool) (*morph.Morph, error) {
|
|||||||
var driver drivers.Driver
|
var driver drivers.Driver
|
||||||
switch ss.DriverName() {
|
switch ss.DriverName() {
|
||||||
case model.DatabaseDriverMysql:
|
case model.DatabaseDriverMysql:
|
||||||
dataSource, rErr := ResetReadTimeout(*ss.settings.DataSource)
|
dataSource, rErr := sqlUtils.ResetReadTimeout(*ss.settings.DataSource)
|
||||||
if rErr != nil {
|
if rErr != nil {
|
||||||
mlog.Fatal("Failed to reset read timeout from datasource.", mlog.Err(rErr), mlog.String("src", *ss.settings.DataSource))
|
mlog.Fatal("Failed to reset read timeout from datasource.", mlog.Err(rErr), mlog.String("src", *ss.settings.DataSource))
|
||||||
return nil, rErr
|
return nil, rErr
|
||||||
}
|
}
|
||||||
dataSource, err = AppendMultipleStatementsFlag(dataSource)
|
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
db, err2 := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
db, err2 := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, err2
|
return nil, err2
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,8 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||||
|
|
||||||
sq "github.com/mattermost/squirrel"
|
sq "github.com/mattermost/squirrel"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
"github.com/go-sql-driver/mysql"
|
||||||
@ -44,7 +46,6 @@ const (
|
|||||||
PGDuplicateObjectErrorCode = "42710"
|
PGDuplicateObjectErrorCode = "42710"
|
||||||
MySQLDuplicateObjectErrorCode = 1022
|
MySQLDuplicateObjectErrorCode = 1022
|
||||||
DBPingAttempts = 5
|
DBPingAttempts = 5
|
||||||
DBPingTimeoutSecs = 10
|
|
||||||
// This is a numerical version string by postgres. The format is
|
// This is a numerical version string by postgres. The format is
|
||||||
// 2 characters for major, minor, and patch version prior to 10.
|
// 2 characters for major, minor, and patch version prior to 10.
|
||||||
// After 10, it's major and minor only.
|
// After 10, it's major and minor only.
|
||||||
@ -241,58 +242,6 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface
|
|||||||
return store, nil
|
return store, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
|
|
||||||
// It also applies any database configuration settings that are required.
|
|
||||||
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
|
|
||||||
db, err := dbsql.Open(*settings.DriverName, dataSource)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to open SQL connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
|
|
||||||
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
|
|
||||||
|
|
||||||
logger = logger.With(
|
|
||||||
mlog.String("database", connType),
|
|
||||||
mlog.String("dataSource", sanitized),
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := 0; i < attempts; i++ {
|
|
||||||
logger.Info("Pinging SQL")
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
err = db.PingContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
if i == attempts-1 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
|
|
||||||
time.Sleep(DBPingTimeoutSecs * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(connType, replicaLagPrefix) {
|
|
||||||
// If this is a replica lag connection, we just open one connection.
|
|
||||||
//
|
|
||||||
// Arguably, if the query doesn't require a special credential, it does take up
|
|
||||||
// one extra connection from the replica DB. But falling back to the replica
|
|
||||||
// data source when the replica lag data source is null implies an ordering constraint
|
|
||||||
// which makes things brittle and is not a good design.
|
|
||||||
// If connections are an overhead, it is advised to use a connection pool.
|
|
||||||
db.SetMaxOpenConns(1)
|
|
||||||
db.SetMaxIdleConns(1)
|
|
||||||
} else {
|
|
||||||
db.SetMaxIdleConns(*settings.MaxIdleConns)
|
|
||||||
db.SetMaxOpenConns(*settings.MaxOpenConns)
|
|
||||||
}
|
|
||||||
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
|
|
||||||
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *SqlStore) SetContext(context context.Context) {
|
func (ss *SqlStore) SetContext(context context.Context) {
|
||||||
ss.context = context
|
ss.context = context
|
||||||
}
|
}
|
||||||
@ -314,13 +263,13 @@ func (ss *SqlStore) initConnection() error {
|
|||||||
// covers that already. Ideally we'd like to do this only for the upgrade
|
// covers that already. Ideally we'd like to do this only for the upgrade
|
||||||
// step. To be reviewed in MM-35789.
|
// step. To be reviewed in MM-35789.
|
||||||
var err error
|
var err error
|
||||||
dataSource, err = ResetReadTimeout(dataSource)
|
dataSource, err = sqlUtils.ResetReadTimeout(dataSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to reset read timeout from datasource")
|
return errors.Wrap(err, "failed to reset read timeout from datasource")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
handle, err := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
handle, err := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -338,7 +287,7 @@ func (ss *SqlStore) initConnection() error {
|
|||||||
ss.ReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceReplicas))
|
ss.ReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceReplicas))
|
||||||
for i, replica := range ss.settings.DataSourceReplicas {
|
for i, replica := range ss.settings.DataSourceReplicas {
|
||||||
ss.ReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
ss.ReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
||||||
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
|
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Initializing to be offline
|
// Initializing to be offline
|
||||||
ss.ReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
ss.ReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
||||||
@ -353,7 +302,7 @@ func (ss *SqlStore) initConnection() error {
|
|||||||
ss.searchReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceSearchReplicas))
|
ss.searchReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceSearchReplicas))
|
||||||
for i, replica := range ss.settings.DataSourceSearchReplicas {
|
for i, replica := range ss.settings.DataSourceSearchReplicas {
|
||||||
ss.searchReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
ss.searchReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
||||||
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
|
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Initializing to be offline
|
// Initializing to be offline
|
||||||
ss.searchReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
ss.searchReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
||||||
@ -370,7 +319,7 @@ func (ss *SqlStore) initConnection() error {
|
|||||||
if src.DataSource == nil {
|
if src.DataSource == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ss.replicaLagHandles[i], err = SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
|
ss.replicaLagHandles[i], err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mlog.Warn("Failed to setup replica lag handle. Skipping..", mlog.String("db", fmt.Sprintf(replicaLagPrefix+"-%d", i)), mlog.Err(err))
|
mlog.Warn("Failed to setup replica lag handle. Skipping..", mlog.String("db", fmt.Sprintf(replicaLagPrefix+"-%d", i)), mlog.Err(err))
|
||||||
continue
|
continue
|
||||||
@ -525,7 +474,7 @@ func (ss *SqlStore) monitorReplicas() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
handle, err := SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
|
handle, err := sqlUtils.SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mlog.Warn("Failed to setup connection. Skipping..", mlog.String("db", name), mlog.Err(err))
|
mlog.Warn("Failed to setup connection. Skipping..", mlog.String("db", name), mlog.Err(err))
|
||||||
return
|
return
|
||||||
|
@ -5,7 +5,6 @@ package sqlstore
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -16,8 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var escapeLikeSearchChar = []string{
|
var escapeLikeSearchChar = []string{
|
||||||
@ -183,57 +180,6 @@ func AppendBinaryFlag(buf []byte) []byte {
|
|||||||
return append([]byte{0x01}, buf...)
|
return append([]byte{0x01}, buf...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
|
|
||||||
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
|
|
||||||
config, err := mysql.ParseDSN(dataSource)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Params == nil {
|
|
||||||
config.Params = map[string]string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Params["multiStatements"] = "true"
|
|
||||||
return config.FormatDSN(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
|
|
||||||
func ResetReadTimeout(dataSource string) (string, error) {
|
|
||||||
config, err := mysql.ParseDSN(dataSource)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
config.ReadTimeout = 0
|
|
||||||
return config.FormatDSN(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeDataSource(driverName, dataSource string) (string, error) {
|
|
||||||
switch driverName {
|
|
||||||
case model.DatabaseDriverPostgres:
|
|
||||||
u, err := url.Parse(dataSource)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
u.User = url.UserPassword("****", "****")
|
|
||||||
params := u.Query()
|
|
||||||
params.Del("user")
|
|
||||||
params.Del("password")
|
|
||||||
u.RawQuery = params.Encode()
|
|
||||||
return u.String(), nil
|
|
||||||
case model.DatabaseDriverMysql:
|
|
||||||
cfg, err := mysql.ParseDSN(dataSource)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
cfg.User = "****"
|
|
||||||
cfg.Passwd = "****"
|
|
||||||
return cfg.FormatDSN(), nil
|
|
||||||
default:
|
|
||||||
return "", errors.New("invalid drivername. Not postgres or mysql.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxTokenSize = 50
|
const maxTokenSize = 50
|
||||||
|
|
||||||
// trimInput limits the string to a max size to prevent clogging up disk space
|
// trimInput limits the string to a max size to prevent clogging up disk space
|
||||||
|
@ -6,7 +6,6 @@ package sqlstore
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -134,72 +133,3 @@ func TestMySQLJSONArgs(t *testing.T) {
|
|||||||
assert.Equal(t, test.argString, argString)
|
assert.Equal(t, test.argString, argString)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAppendMultipleStatementsFlag(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
Scenario string
|
|
||||||
DSN string
|
|
||||||
ExpectedDSN string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"Should append multiStatements param to the DSN path with existing params",
|
|
||||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
|
|
||||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"Should append multiStatements param to the DSN path with no existing params",
|
|
||||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
|
||||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.Scenario, func(t *testing.T) {
|
|
||||||
res, err := AppendMultipleStatementsFlag(tc.DSN)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.ExpectedDSN, res)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeDataSource(t *testing.T) {
|
|
||||||
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
Original string
|
|
||||||
Sanitized string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
|
|
||||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
|
|
||||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
driver := model.DatabaseDriverPostgres
|
|
||||||
for _, tc := range testCases {
|
|
||||||
out, err := SanitizeDataSource(driver, tc.Original)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.Sanitized, out)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
Original string
|
|
||||||
Sanitized string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
|
|
||||||
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
driver := model.DatabaseDriverMysql
|
|
||||||
for _, tc := range testCases {
|
|
||||||
out, err := SanitizeDataSource(driver, tc.Original)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.Sanitized, out)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@ -15,6 +15,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
@ -27,8 +29,6 @@ import (
|
|||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
|
|
||||||
|
|
||||||
"github.com/mattermost/morph/drivers"
|
"github.com/mattermost/morph/drivers"
|
||||||
ms "github.com/mattermost/morph/drivers/mysql"
|
ms "github.com/mattermost/morph/drivers/mysql"
|
||||||
ps "github.com/mattermost/morph/drivers/postgres"
|
ps "github.com/mattermost/morph/drivers/postgres"
|
||||||
@ -121,12 +121,12 @@ func (ds *DatabaseStore) initializeConfigurationsTable() error {
|
|||||||
var driver drivers.Driver
|
var driver drivers.Driver
|
||||||
switch ds.driverName {
|
switch ds.driverName {
|
||||||
case model.DatabaseDriverMysql:
|
case model.DatabaseDriverMysql:
|
||||||
dataSource, rErr := sqlstore.ResetReadTimeout(ds.dataSourceName)
|
dataSource, rErr := sqlUtils.ResetReadTimeout(ds.dataSourceName)
|
||||||
if rErr != nil {
|
if rErr != nil {
|
||||||
return fmt.Errorf("failed to reset read timeout from datasource: %w", rErr)
|
return fmt.Errorf("failed to reset read timeout from datasource: %w", rErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSource, err = sqlstore.AppendMultipleStatementsFlag(dataSource)
|
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -409,7 +409,7 @@ func (ds *DatabaseStore) RemoveFile(name string) error {
|
|||||||
func (ds *DatabaseStore) String() string {
|
func (ds *DatabaseStore) String() string {
|
||||||
// This is called during the running of MM, so we expect the parsing of DSN
|
// This is called during the running of MM, so we expect the parsing of DSN
|
||||||
// to be successful.
|
// to be successful.
|
||||||
sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn)
|
sanitized, _ := sqlUtils.SanitizeDataSource(ds.driverName, ds.originalDsn)
|
||||||
return sanitized
|
return sanitized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
126
server/public/utils/sql/sql_utils.go
Normal file
126
server/public/utils/sql/sql_utils.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
dbsql "database/sql"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DBPingTimeoutSecs = 10
|
||||||
|
|
||||||
|
replicaLagPrefix = "replica-lag"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
|
||||||
|
func ResetReadTimeout(dataSource string) (string, error) {
|
||||||
|
config, err := mysql.ParseDSN(dataSource)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
config.ReadTimeout = 0
|
||||||
|
return config.FormatDSN(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
|
||||||
|
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
|
||||||
|
config, err := mysql.ParseDSN(dataSource)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Params == nil {
|
||||||
|
config.Params = map[string]string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Params["multiStatements"] = "true"
|
||||||
|
return config.FormatDSN(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
|
||||||
|
// It also applies any database configuration settings that are required.
|
||||||
|
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
|
||||||
|
db, err := dbsql.Open(*settings.DriverName, dataSource)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to open SQL connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
|
||||||
|
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
|
||||||
|
|
||||||
|
logger = logger.With(
|
||||||
|
mlog.String("database", connType),
|
||||||
|
mlog.String("dataSource", sanitized),
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < attempts; i++ {
|
||||||
|
logger.Info("Pinging SQL")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err = db.PingContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if i == attempts-1 {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
|
||||||
|
time.Sleep(DBPingTimeoutSecs * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(connType, replicaLagPrefix) {
|
||||||
|
// If this is a replica lag connection, we just open one connection.
|
||||||
|
//
|
||||||
|
// Arguably, if the query doesn't require a special credential, it does take up
|
||||||
|
// one extra connection from the replica DB. But falling back to the replica
|
||||||
|
// data source when the replica lag data source is null implies an ordering constraint
|
||||||
|
// which makes things brittle and is not a good design.
|
||||||
|
// If connections are an overhead, it is advised to use a connection pool.
|
||||||
|
db.SetMaxOpenConns(1)
|
||||||
|
db.SetMaxIdleConns(1)
|
||||||
|
} else {
|
||||||
|
db.SetMaxIdleConns(*settings.MaxIdleConns)
|
||||||
|
db.SetMaxOpenConns(*settings.MaxOpenConns)
|
||||||
|
}
|
||||||
|
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
|
||||||
|
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeDataSource(driverName, dataSource string) (string, error) {
|
||||||
|
switch driverName {
|
||||||
|
case model.DatabaseDriverPostgres:
|
||||||
|
u, err := url.Parse(dataSource)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
u.User = url.UserPassword("****", "****")
|
||||||
|
params := u.Query()
|
||||||
|
params.Del("user")
|
||||||
|
params.Del("password")
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
return u.String(), nil
|
||||||
|
case model.DatabaseDriverMysql:
|
||||||
|
cfg, err := mysql.ParseDSN(dataSource)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cfg.User = "****"
|
||||||
|
cfg.Passwd = "****"
|
||||||
|
return cfg.FormatDSN(), nil
|
||||||
|
default:
|
||||||
|
return "", errors.New("invalid drivername. Not postgres or mysql.")
|
||||||
|
}
|
||||||
|
}
|
109
server/public/utils/sql/sql_utils_test.go
Normal file
109
server/public/utils/sql/sql_utils_test.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppendMultipleStatementsFlag(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Scenario string
|
||||||
|
DSN string
|
||||||
|
ExpectedDSN string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"Should append multiStatements param to the DSN path with existing params",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Should append multiStatements param to the DSN path with no existing params",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Scenario, func(t *testing.T) {
|
||||||
|
res, err := AppendMultipleStatementsFlag(tc.DSN)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.ExpectedDSN, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetReadTimeout(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Scenario string
|
||||||
|
DSN string
|
||||||
|
ExpectedDSN string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"Should re move read timeout param from the DSN",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?readTimeout=30s",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Should change nothing as there is no read timeout param specified",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||||
|
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Scenario, func(t *testing.T) {
|
||||||
|
res, err := ResetReadTimeout(tc.DSN)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.ExpectedDSN, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeDataSource(t *testing.T) {
|
||||||
|
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Original string
|
||||||
|
Sanitized string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
|
||||||
|
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
|
||||||
|
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
driver := model.DatabaseDriverPostgres
|
||||||
|
for _, tc := range testCases {
|
||||||
|
out, err := SanitizeDataSource(driver, tc.Original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.Sanitized, out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Original string
|
||||||
|
Sanitized string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
|
||||||
|
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
driver := model.DatabaseDriverMysql
|
||||||
|
for _, tc := range testCases {
|
||||||
|
out, err := SanitizeDataSource(driver, tc.Original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.Sanitized, out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user