Moved some common SQL function to public utils as they are used in plugins (#26412)

* Moved some common SQL function tu public utls as they are used in plugins

* goimported file

* Added tests

* Created sub-package

* MOved SetupConnection to public sql utils
This commit is contained in:
Harshil Sharma 2024-03-11 09:54:23 +05:30 committed by GitHub
parent f34445a6f4
commit 4fda7e6f34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 253 additions and 191 deletions

View File

@ -12,6 +12,8 @@ import (
"strconv"
"sync"
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/db"
@ -119,16 +121,16 @@ func (ss *SqlStore) initMorph(dryRun bool) (*morph.Morph, error) {
var driver drivers.Driver
switch ss.DriverName() {
case model.DatabaseDriverMysql:
dataSource, rErr := ResetReadTimeout(*ss.settings.DataSource)
dataSource, rErr := sqlUtils.ResetReadTimeout(*ss.settings.DataSource)
if rErr != nil {
mlog.Fatal("Failed to reset read timeout from datasource.", mlog.Err(rErr), mlog.String("src", *ss.settings.DataSource))
return nil, rErr
}
dataSource, err = AppendMultipleStatementsFlag(dataSource)
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
if err != nil {
return nil, err
}
db, err2 := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
db, err2 := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
if err2 != nil {
return nil, err2
}

View File

@ -15,6 +15,8 @@ import (
"sync/atomic"
"time"
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
sq "github.com/mattermost/squirrel"
"github.com/go-sql-driver/mysql"
@ -44,7 +46,6 @@ const (
PGDuplicateObjectErrorCode = "42710"
MySQLDuplicateObjectErrorCode = 1022
DBPingAttempts = 5
DBPingTimeoutSecs = 10
// This is a numerical version string by postgres. The format is
// 2 characters for major, minor, and patch version prior to 10.
// After 10, it's major and minor only.
@ -241,58 +242,6 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface
return store, nil
}
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
// It also applies any database configuration settings that are required.
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
db, err := dbsql.Open(*settings.DriverName, dataSource)
if err != nil {
return nil, errors.Wrap(err, "failed to open SQL connection")
}
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
logger = logger.With(
mlog.String("database", connType),
mlog.String("dataSource", sanitized),
)
for i := 0; i < attempts; i++ {
logger.Info("Pinging SQL")
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
defer cancel()
err = db.PingContext(ctx)
if err != nil {
if i == attempts-1 {
return nil, err
}
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
time.Sleep(DBPingTimeoutSecs * time.Second)
continue
}
break
}
if strings.HasPrefix(connType, replicaLagPrefix) {
// If this is a replica lag connection, we just open one connection.
//
// Arguably, if the query doesn't require a special credential, it does take up
// one extra connection from the replica DB. But falling back to the replica
// data source when the replica lag data source is null implies an ordering constraint
// which makes things brittle and is not a good design.
// If connections are an overhead, it is advised to use a connection pool.
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxIdleConns(*settings.MaxIdleConns)
db.SetMaxOpenConns(*settings.MaxOpenConns)
}
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
return db, nil
}
func (ss *SqlStore) SetContext(context context.Context) {
ss.context = context
}
@ -314,13 +263,13 @@ func (ss *SqlStore) initConnection() error {
// covers that already. Ideally we'd like to do this only for the upgrade
// step. To be reviewed in MM-35789.
var err error
dataSource, err = ResetReadTimeout(dataSource)
dataSource, err = sqlUtils.ResetReadTimeout(dataSource)
if err != nil {
return errors.Wrap(err, "failed to reset read timeout from datasource")
}
}
handle, err := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
handle, err := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts)
if err != nil {
return err
}
@ -338,7 +287,7 @@ func (ss *SqlStore) initConnection() error {
ss.ReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceReplicas))
for i, replica := range ss.settings.DataSourceReplicas {
ss.ReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts)
if err != nil {
// Initializing to be offline
ss.ReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
@ -353,7 +302,7 @@ func (ss *SqlStore) initConnection() error {
ss.searchReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceSearchReplicas))
for i, replica := range ss.settings.DataSourceSearchReplicas {
ss.searchReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{}
handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts)
if err != nil {
// Initializing to be offline
ss.searchReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}})
@ -370,7 +319,7 @@ func (ss *SqlStore) initConnection() error {
if src.DataSource == nil {
continue
}
ss.replicaLagHandles[i], err = SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
ss.replicaLagHandles[i], err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts)
if err != nil {
mlog.Warn("Failed to setup replica lag handle. Skipping..", mlog.String("db", fmt.Sprintf(replicaLagPrefix+"-%d", i)), mlog.Err(err))
continue
@ -525,7 +474,7 @@ func (ss *SqlStore) monitorReplicas() {
return
}
handle, err := SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
handle, err := sqlUtils.SetupConnection(ss.Logger(), name, dsn, ss.settings, 1)
if err != nil {
mlog.Warn("Failed to setup connection. Skipping..", mlog.String("db", name), mlog.Err(err))
return

View File

@ -5,7 +5,6 @@ package sqlstore
import (
"database/sql"
"errors"
"io"
"net/url"
"strconv"
@ -16,8 +15,6 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/go-sql-driver/mysql"
)
var escapeLikeSearchChar = []string{
@ -183,57 +180,6 @@ func AppendBinaryFlag(buf []byte) []byte {
return append([]byte{0x01}, buf...)
}
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
config, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
if config.Params == nil {
config.Params = map[string]string{}
}
config.Params["multiStatements"] = "true"
return config.FormatDSN(), nil
}
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
func ResetReadTimeout(dataSource string) (string, error) {
config, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
config.ReadTimeout = 0
return config.FormatDSN(), nil
}
func SanitizeDataSource(driverName, dataSource string) (string, error) {
switch driverName {
case model.DatabaseDriverPostgres:
u, err := url.Parse(dataSource)
if err != nil {
return "", err
}
u.User = url.UserPassword("****", "****")
params := u.Query()
params.Del("user")
params.Del("password")
u.RawQuery = params.Encode()
return u.String(), nil
case model.DatabaseDriverMysql:
cfg, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
cfg.User = "****"
cfg.Passwd = "****"
return cfg.FormatDSN(), nil
default:
return "", errors.New("invalid drivername. Not postgres or mysql.")
}
}
const maxTokenSize = 50
// trimInput limits the string to a max size to prevent clogging up disk space

View File

@ -6,7 +6,6 @@ package sqlstore
import (
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -134,72 +133,3 @@ func TestMySQLJSONArgs(t *testing.T) {
assert.Equal(t, test.argString, argString)
}
}
func TestAppendMultipleStatementsFlag(t *testing.T) {
testCases := []struct {
Scenario string
DSN string
ExpectedDSN string
}{
{
"Should append multiStatements param to the DSN path with existing params",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
},
{
"Should append multiStatements param to the DSN path with no existing params",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
},
}
for _, tc := range testCases {
t.Run(tc.Scenario, func(t *testing.T) {
res, err := AppendMultipleStatementsFlag(tc.DSN)
require.NoError(t, err)
assert.Equal(t, tc.ExpectedDSN, res)
})
}
}
func TestSanitizeDataSource(t *testing.T) {
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
{
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
}
driver := model.DatabaseDriverPostgres
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
},
}
driver := model.DatabaseDriverMysql
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})
}

View File

@ -15,6 +15,8 @@ import (
"path/filepath"
"strings"
sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
@ -27,8 +29,6 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
"github.com/mattermost/morph/drivers"
ms "github.com/mattermost/morph/drivers/mysql"
ps "github.com/mattermost/morph/drivers/postgres"
@ -121,12 +121,12 @@ func (ds *DatabaseStore) initializeConfigurationsTable() error {
var driver drivers.Driver
switch ds.driverName {
case model.DatabaseDriverMysql:
dataSource, rErr := sqlstore.ResetReadTimeout(ds.dataSourceName)
dataSource, rErr := sqlUtils.ResetReadTimeout(ds.dataSourceName)
if rErr != nil {
return fmt.Errorf("failed to reset read timeout from datasource: %w", rErr)
}
dataSource, err = sqlstore.AppendMultipleStatementsFlag(dataSource)
dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource)
if err != nil {
return err
}
@ -409,7 +409,7 @@ func (ds *DatabaseStore) RemoveFile(name string) error {
func (ds *DatabaseStore) String() string {
// This is called during the running of MM, so we expect the parsing of DSN
// to be successful.
sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn)
sanitized, _ := sqlUtils.SanitizeDataSource(ds.driverName, ds.originalDsn)
return sanitized
}

View File

@ -0,0 +1,126 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sql
import (
"context"
dbsql "database/sql"
"net/url"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/pkg/errors"
)
const (
DBPingTimeoutSecs = 10
replicaLagPrefix = "replica-lag"
)
// ResetReadTimeout removes the timeout constraint from the MySQL dsn.
func ResetReadTimeout(dataSource string) (string, error) {
config, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
config.ReadTimeout = 0
return config.FormatDSN(), nil
}
// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work.
func AppendMultipleStatementsFlag(dataSource string) (string, error) {
config, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
if config.Params == nil {
config.Params = map[string]string{}
}
config.Params["multiStatements"] = "true"
return config.FormatDSN(), nil
}
// SetupConnection sets up the connection to the database and pings it to make sure it's alive.
// It also applies any database configuration settings that are required.
func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) {
db, err := dbsql.Open(*settings.DriverName, dataSource)
if err != nil {
return nil, errors.Wrap(err, "failed to open SQL connection")
}
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
logger = logger.With(
mlog.String("database", connType),
mlog.String("dataSource", sanitized),
)
for i := 0; i < attempts; i++ {
logger.Info("Pinging SQL")
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
defer cancel()
err = db.PingContext(ctx)
if err != nil {
if i == attempts-1 {
return nil, err
}
logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err))
time.Sleep(DBPingTimeoutSecs * time.Second)
continue
}
break
}
if strings.HasPrefix(connType, replicaLagPrefix) {
// If this is a replica lag connection, we just open one connection.
//
// Arguably, if the query doesn't require a special credential, it does take up
// one extra connection from the replica DB. But falling back to the replica
// data source when the replica lag data source is null implies an ordering constraint
// which makes things brittle and is not a good design.
// If connections are an overhead, it is advised to use a connection pool.
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxIdleConns(*settings.MaxIdleConns)
db.SetMaxOpenConns(*settings.MaxOpenConns)
}
db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond)
db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond)
return db, nil
}
func SanitizeDataSource(driverName, dataSource string) (string, error) {
switch driverName {
case model.DatabaseDriverPostgres:
u, err := url.Parse(dataSource)
if err != nil {
return "", err
}
u.User = url.UserPassword("****", "****")
params := u.Query()
params.Del("user")
params.Del("password")
u.RawQuery = params.Encode()
return u.String(), nil
case model.DatabaseDriverMysql:
cfg, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
cfg.User = "****"
cfg.Passwd = "****"
return cfg.FormatDSN(), nil
default:
return "", errors.New("invalid drivername. Not postgres or mysql.")
}
}

View File

@ -0,0 +1,109 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sql
import (
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAppendMultipleStatementsFlag(t *testing.T) {
testCases := []struct {
Scenario string
DSN string
ExpectedDSN string
}{
{
"Should append multiStatements param to the DSN path with existing params",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true",
},
{
"Should append multiStatements param to the DSN path with no existing params",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true",
},
}
for _, tc := range testCases {
t.Run(tc.Scenario, func(t *testing.T) {
res, err := AppendMultipleStatementsFlag(tc.DSN)
require.NoError(t, err)
assert.Equal(t, tc.ExpectedDSN, res)
})
}
}
func TestResetReadTimeout(t *testing.T) {
testCases := []struct {
Scenario string
DSN string
ExpectedDSN string
}{
{
"Should re move read timeout param from the DSN",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?readTimeout=30s",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
},
{
"Should change nothing as there is no read timeout param specified",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
"user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost",
},
}
for _, tc := range testCases {
t.Run(tc.Scenario, func(t *testing.T) {
res, err := ResetReadTimeout(tc.DSN)
require.NoError(t, err)
assert.Equal(t, tc.ExpectedDSN, res)
})
}
}
func TestSanitizeDataSource(t *testing.T) {
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
{
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
}
driver := model.DatabaseDriverPostgres
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})
t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
},
}
driver := model.DatabaseDriverMysql
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})
}