mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
Moved some common SQL function to public utils as they are used in plugins (#26412)
* Moved some common SQL function tu public utls as they are used in plugins * goimported file * Added tests * Created sub-package * MOved SetupConnection to public sql utils
This commit is contained in:
parent
f34445a6f4
commit
4fda7e6f34
@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/db"
|
||||
@ -119,16 +121,16 @@ func (ss *SqlStore) initMorph(dryRun bool) (*morph.Morph, error) {
|
||||
var driver drivers.Driver
|
||||
switch ss.DriverName() {
|
||||
case model.DatabaseDriverMysql:
|
||||
dataSource, rErr := ResetReadTimeout(*ss.settings.DataSource)
|
||||
dataSource, rErr := sqlUtils.ResetReadTimeout(*ss.settings.DataSource)
|
||||
if rErr != nil {
|
||||
mlog.Fatal("Failed to reset read timeout from datasource.", mlog.Err(rErr), mlog.String("src", *ss.settings.DataSource))
|
||||
return nil, rErr
|
||||
}
|
||||
dataSource, err = AppendMultipleStatementsFlag(dataSource)
|
||||
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err2 := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||
db, err2 := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||
|
||||
sq "github.com/mattermost/squirrel"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
@ -44,7 +46,6 @@ const (
|
||||
PGDuplicateObjectErrorCode = "42710"
|
||||
MySQLDuplicateObjectErrorCode = 1022
|
||||
DBPingAttempts = 5
|
||||
DBPingTimeoutSecs = 10
|
||||
// This is a numerical version string by postgres. The format is
|
||||
// 2 characters for major, minor, and patch version prior to 10.
|
||||
// After 10, it's major and minor only.
|
||||
@ -241,58 +242,6 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
|
||||
// It also applies any database configuration settings that are required.
|
||||
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
|
||||
db, err := dbsql.Open(*settings.DriverName, dataSource)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open SQL connection")
|
||||
}
|
||||
|
||||
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
|
||||
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
|
||||
|
||||
logger = logger.With(
|
||||
mlog.String("database", connType),
|
||||
mlog.String("dataSource", sanitized),
|
||||
)
|
||||
|
||||
for i := 0; i < attempts; i++ {
|
||||
logger.Info("Pinging SQL")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
|
||||
defer cancel()
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
if i == attempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
|
||||
time.Sleep(DBPingTimeoutSecs * time.Second)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if strings.HasPrefix(connType, replicaLagPrefix) {
|
||||
// If this is a replica lag connection, we just open one connection.
|
||||
//
|
||||
// Arguably, if the query doesn't require a special credential, it does take up
|
||||
// one extra connection from the replica DB. But falling back to the replica
|
||||
// data source when the replica lag data source is null implies an ordering constraint
|
||||
// which makes things brittle and is not a good design.
|
||||
// If connections are an overhead, it is advised to use a connection pool.
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
} else {
|
||||
db.SetMaxIdleConns(*settings.MaxIdleConns)
|
||||
db.SetMaxOpenConns(*settings.MaxOpenConns)
|
||||
}
|
||||
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
|
||||
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (ss *SqlStore) SetContext(context context.Context) {
|
||||
ss.context = context
|
||||
}
|
||||
@ -314,13 +263,13 @@ func (ss *SqlStore) initConnection() error {
|
||||
// covers that already. Ideally we'd like to do this only for the upgrade
|
||||
// step. To be reviewed in MM-35789.
|
||||
var err error
|
||||
dataSource, err = ResetReadTimeout(dataSource)
|
||||
dataSource, err = sqlUtils.ResetReadTimeout(dataSource)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to reset read timeout from datasource")
|
||||
}
|
||||
}
|
||||
|
||||
handle, err := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||
handle, err := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -338,7 +287,7 @@ func (ss *SqlStore) initConnection() error {
|
||||
ss.ReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceReplicas))
|
||||
for i, replica := range ss.settings.DataSourceReplicas {
|
||||
ss.ReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
||||
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||
if err != nil {
|
||||
// Initializing to be offline
|
||||
ss.ReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
||||
@ -353,7 +302,7 @@ func (ss *SqlStore) initConnection() error {
|
||||
ss.searchReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceSearchReplicas))
|
||||
for i, replica := range ss.settings.DataSourceSearchReplicas {
|
||||
ss.searchReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
|
||||
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
|
||||
if err != nil {
|
||||
// Initializing to be offline
|
||||
ss.searchReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
|
||||
@ -370,7 +319,7 @@ func (ss *SqlStore) initConnection() error {
|
||||
if src.DataSource == nil {
|
||||
continue
|
||||
}
|
||||
ss.replicaLagHandles[i], err = SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
|
||||
ss.replicaLagHandles[i], err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
|
||||
if err != nil {
|
||||
mlog.Warn("Failed to setup replica lag handle. Skipping..", mlog.String("db", fmt.Sprintf(replicaLagPrefix+"-%d", i)), mlog.Err(err))
|
||||
continue
|
||||
@ -525,7 +474,7 @@ func (ss *SqlStore) monitorReplicas() {
|
||||
return
|
||||
}
|
||||
|
||||
handle, err := SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
|
||||
handle, err := sqlUtils.SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
|
||||
if err != nil {
|
||||
mlog.Warn("Failed to setup connection. Skipping..", mlog.String("db", name), mlog.Err(err))
|
||||
return
|
||||
|
@ -5,7 +5,6 @@ package sqlstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@ -16,8 +15,6 @@ import (
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
var escapeLikeSearchChar = []string{
|
||||
@ -183,57 +180,6 @@ func AppendBinaryFlag(buf []byte) []byte {
|
||||
return append([]byte{0x01}, buf...)
|
||||
}
|
||||
|
||||
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
|
||||
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
|
||||
config, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if config.Params == nil {
|
||||
config.Params = map[string]string{}
|
||||
}
|
||||
|
||||
config.Params["multiStatements"] = "true"
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
|
||||
func ResetReadTimeout(dataSource string) (string, error) {
|
||||
config, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
config.ReadTimeout = 0
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
func SanitizeDataSource(driverName, dataSource string) (string, error) {
|
||||
switch driverName {
|
||||
case model.DatabaseDriverPostgres:
|
||||
u, err := url.Parse(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
u.User = url.UserPassword("****", "****")
|
||||
params := u.Query()
|
||||
params.Del("user")
|
||||
params.Del("password")
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String(), nil
|
||||
case model.DatabaseDriverMysql:
|
||||
cfg, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cfg.User = "****"
|
||||
cfg.Passwd = "****"
|
||||
return cfg.FormatDSN(), nil
|
||||
default:
|
||||
return "", errors.New("invalid drivername. Not postgres or mysql.")
|
||||
}
|
||||
}
|
||||
|
||||
const maxTokenSize = 50
|
||||
|
||||
// trimInput limits the string to a max size to prevent clogging up disk space
|
||||
|
@ -6,7 +6,6 @@ package sqlstore
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -134,72 +133,3 @@ func TestMySQLJSONArgs(t *testing.T) {
|
||||
assert.Equal(t, test.argString, argString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendMultipleStatementsFlag(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Scenario string
|
||||
DSN string
|
||||
ExpectedDSN string
|
||||
}{
|
||||
{
|
||||
"Should append multiStatements param to the DSN path with existing params",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
|
||||
},
|
||||
{
|
||||
"Should append multiStatements param to the DSN path with no existing params",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Scenario, func(t *testing.T) {
|
||||
res, err := AppendMultipleStatementsFlag(tc.DSN)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.ExpectedDSN, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeDataSource(t *testing.T) {
|
||||
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Original string
|
||||
Sanitized string
|
||||
}{
|
||||
{
|
||||
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
|
||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||
},
|
||||
{
|
||||
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
|
||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||
},
|
||||
}
|
||||
driver := model.DatabaseDriverPostgres
|
||||
for _, tc := range testCases {
|
||||
out, err := SanitizeDataSource(driver, tc.Original)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.Sanitized, out)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Original string
|
||||
Sanitized string
|
||||
}{
|
||||
{
|
||||
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
|
||||
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
|
||||
},
|
||||
}
|
||||
driver := model.DatabaseDriverMysql
|
||||
for _, tc := range testCases {
|
||||
out, err := SanitizeDataSource(driver, tc.Original)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.Sanitized, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@ -27,8 +29,6 @@ import (
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
|
||||
|
||||
"github.com/mattermost/morph/drivers"
|
||||
ms "github.com/mattermost/morph/drivers/mysql"
|
||||
ps "github.com/mattermost/morph/drivers/postgres"
|
||||
@ -121,12 +121,12 @@ func (ds *DatabaseStore) initializeConfigurationsTable() error {
|
||||
var driver drivers.Driver
|
||||
switch ds.driverName {
|
||||
case model.DatabaseDriverMysql:
|
||||
dataSource, rErr := sqlstore.ResetReadTimeout(ds.dataSourceName)
|
||||
dataSource, rErr := sqlUtils.ResetReadTimeout(ds.dataSourceName)
|
||||
if rErr != nil {
|
||||
return fmt.Errorf("failed to reset read timeout from datasource: %w", rErr)
|
||||
}
|
||||
|
||||
dataSource, err = sqlstore.AppendMultipleStatementsFlag(dataSource)
|
||||
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -409,7 +409,7 @@ func (ds *DatabaseStore) RemoveFile(name string) error {
|
||||
func (ds *DatabaseStore) String() string {
|
||||
// This is called during the running of MM, so we expect the parsing of DSN
|
||||
// to be successful.
|
||||
sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn)
|
||||
sanitized, _ := sqlUtils.SanitizeDataSource(ds.driverName, ds.originalDsn)
|
||||
return sanitized
|
||||
}
|
||||
|
||||
|
126
server/public/utils/sql/sql_utils.go
Normal file
126
server/public/utils/sql/sql_utils.go
Normal file
@ -0,0 +1,126 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
dbsql "database/sql"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
DBPingTimeoutSecs = 10
|
||||
|
||||
replicaLagPrefix = "replica-lag"
|
||||
)
|
||||
|
||||
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
|
||||
func ResetReadTimeout(dataSource string) (string, error) {
|
||||
config, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
config.ReadTimeout = 0
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
|
||||
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
|
||||
config, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if config.Params == nil {
|
||||
config.Params = map[string]string{}
|
||||
}
|
||||
|
||||
config.Params["multiStatements"] = "true"
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
|
||||
// It also applies any database configuration settings that are required.
|
||||
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
|
||||
db, err := dbsql.Open(*settings.DriverName, dataSource)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open SQL connection")
|
||||
}
|
||||
|
||||
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
|
||||
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
|
||||
|
||||
logger = logger.With(
|
||||
mlog.String("database", connType),
|
||||
mlog.String("dataSource", sanitized),
|
||||
)
|
||||
|
||||
for i := 0; i < attempts; i++ {
|
||||
logger.Info("Pinging SQL")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
|
||||
defer cancel()
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
if i == attempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
|
||||
time.Sleep(DBPingTimeoutSecs * time.Second)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if strings.HasPrefix(connType, replicaLagPrefix) {
|
||||
// If this is a replica lag connection, we just open one connection.
|
||||
//
|
||||
// Arguably, if the query doesn't require a special credential, it does take up
|
||||
// one extra connection from the replica DB. But falling back to the replica
|
||||
// data source when the replica lag data source is null implies an ordering constraint
|
||||
// which makes things brittle and is not a good design.
|
||||
// If connections are an overhead, it is advised to use a connection pool.
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
} else {
|
||||
db.SetMaxIdleConns(*settings.MaxIdleConns)
|
||||
db.SetMaxOpenConns(*settings.MaxOpenConns)
|
||||
}
|
||||
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
|
||||
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func SanitizeDataSource(driverName, dataSource string) (string, error) {
|
||||
switch driverName {
|
||||
case model.DatabaseDriverPostgres:
|
||||
u, err := url.Parse(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
u.User = url.UserPassword("****", "****")
|
||||
params := u.Query()
|
||||
params.Del("user")
|
||||
params.Del("password")
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String(), nil
|
||||
case model.DatabaseDriverMysql:
|
||||
cfg, err := mysql.ParseDSN(dataSource)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cfg.User = "****"
|
||||
cfg.Passwd = "****"
|
||||
return cfg.FormatDSN(), nil
|
||||
default:
|
||||
return "", errors.New("invalid drivername. Not postgres or mysql.")
|
||||
}
|
||||
}
|
109
server/public/utils/sql/sql_utils_test.go
Normal file
109
server/public/utils/sql/sql_utils_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppendMultipleStatementsFlag(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Scenario string
|
||||
DSN string
|
||||
ExpectedDSN string
|
||||
}{
|
||||
{
|
||||
"Should append multiStatements param to the DSN path with existing params",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
|
||||
},
|
||||
{
|
||||
"Should append multiStatements param to the DSN path with no existing params",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Scenario, func(t *testing.T) {
|
||||
res, err := AppendMultipleStatementsFlag(tc.DSN)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.ExpectedDSN, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetReadTimeout(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Scenario string
|
||||
DSN string
|
||||
ExpectedDSN string
|
||||
}{
|
||||
{
|
||||
"Should re move read timeout param from the DSN",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?readTimeout=30s",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||
},
|
||||
{
|
||||
"Should change nothing as there is no read timeout param specified",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Scenario, func(t *testing.T) {
|
||||
res, err := ResetReadTimeout(tc.DSN)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.ExpectedDSN, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeDataSource(t *testing.T) {
|
||||
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Original string
|
||||
Sanitized string
|
||||
}{
|
||||
{
|
||||
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
|
||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||
},
|
||||
{
|
||||
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
|
||||
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
|
||||
},
|
||||
}
|
||||
driver := model.DatabaseDriverPostgres
|
||||
for _, tc := range testCases {
|
||||
out, err := SanitizeDataSource(driver, tc.Original)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.Sanitized, out)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Original string
|
||||
Sanitized string
|
||||
}{
|
||||
{
|
||||
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
|
||||
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
|
||||
},
|
||||
}
|
||||
driver := model.DatabaseDriverMysql
|
||||
for _, tc := range testCases {
|
||||
out, err := SanitizeDataSource(driver, tc.Original)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.Sanitized, out)
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user