From 4fda7e6f3440e6fa443e21882ddc45fb89948a9d Mon Sep 17 00:00:00 2001 From: Harshil Sharma <18575143+harshilsharma63@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:54:23 +0530 Subject: [PATCH] Moved some common SQL function to public utils as they are used in plugins (#26412) * Moved some common SQL function tu public utls as they are used in plugins * goimported file * Added tests * Created sub-package * MOved SetupConnection to public sql utils --- server/channels/store/sqlstore/migrate.go | 8 +- server/channels/store/sqlstore/store.go | 67 ++-------- server/channels/store/sqlstore/utils.go | 54 -------- server/channels/store/sqlstore/utils_test.go | 70 ----------- server/config/database.go | 10 +- server/public/utils/sql/sql_utils.go | 126 +++++++++++++++++++ server/public/utils/sql/sql_utils_test.go | 109 ++++++++++++++++ 7 files changed, 253 insertions(+), 191 deletions(-) create mode 100644 server/public/utils/sql/sql_utils.go create mode 100644 server/public/utils/sql/sql_utils_test.go diff --git a/server/channels/store/sqlstore/migrate.go b/server/channels/store/sqlstore/migrate.go index d17695fece..7a00fa0691 100644 --- a/server/channels/store/sqlstore/migrate.go +++ b/server/channels/store/sqlstore/migrate.go @@ -12,6 +12,8 @@ import ( "strconv" "sync" + sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql" + "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/db" @@ -119,16 +121,16 @@ func (ss *SqlStore) initMorph(dryRun bool) (*morph.Morph, error) { var driver drivers.Driver switch ss.DriverName() { case model.DatabaseDriverMysql: - dataSource, rErr := ResetReadTimeout(*ss.settings.DataSource) + dataSource, rErr := sqlUtils.ResetReadTimeout(*ss.settings.DataSource) if rErr != nil { mlog.Fatal("Failed to reset read timeout from datasource.", mlog.Err(rErr), mlog.String("src", *ss.settings.DataSource)) return nil, rErr } - dataSource, err = AppendMultipleStatementsFlag(dataSource) + dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource) if err != nil { return nil, err } - db, err2 := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts) + db, err2 := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts) if err2 != nil { return nil, err2 } diff --git a/server/channels/store/sqlstore/store.go b/server/channels/store/sqlstore/store.go index 4369e6edd8..a48c05908b 100644 --- a/server/channels/store/sqlstore/store.go +++ b/server/channels/store/sqlstore/store.go @@ -15,6 +15,8 @@ import ( "sync/atomic" "time" + sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql" + sq "github.com/mattermost/squirrel" "github.com/go-sql-driver/mysql" @@ -44,7 +46,6 @@ const ( PGDuplicateObjectErrorCode = "42710" MySQLDuplicateObjectErrorCode = 1022 DBPingAttempts = 5 - DBPingTimeoutSecs = 10 // This is a numerical version string by postgres. The format is // 2 characters for major, minor, and patch version prior to 10. // After 10, it's major and minor only. @@ -241,58 +242,6 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface return store, nil } -// SetupConnection sets up the connection to the database and pings it to make sure it's alive. -// It also applies any database configuration settings that are required. -func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) { - db, err := dbsql.Open(*settings.DriverName, dataSource) - if err != nil { - return nil, errors.Wrap(err, "failed to open SQL connection") - } - - // At this point, we have passed sql.Open, so we deliberately ignore any errors. - sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource) - - logger = logger.With( - mlog.String("database", connType), - mlog.String("dataSource", sanitized), - ) - - for i := 0; i < attempts; i++ { - logger.Info("Pinging SQL") - ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second) - defer cancel() - err = db.PingContext(ctx) - if err != nil { - if i == attempts-1 { - return nil, err - } - logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err)) - time.Sleep(DBPingTimeoutSecs * time.Second) - continue - } - break - } - - if strings.HasPrefix(connType, replicaLagPrefix) { - // If this is a replica lag connection, we just open one connection. - // - // Arguably, if the query doesn't require a special credential, it does take up - // one extra connection from the replica DB. But falling back to the replica - // data source when the replica lag data source is null implies an ordering constraint - // which makes things brittle and is not a good design. - // If connections are an overhead, it is advised to use a connection pool. - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(1) - } else { - db.SetMaxIdleConns(*settings.MaxIdleConns) - db.SetMaxOpenConns(*settings.MaxOpenConns) - } - db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond) - db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond) - - return db, nil -} - func (ss *SqlStore) SetContext(context context.Context) { ss.context = context } @@ -314,13 +263,13 @@ func (ss *SqlStore) initConnection() error { // covers that already. Ideally we'd like to do this only for the upgrade // step. To be reviewed in MM-35789. var err error - dataSource, err = ResetReadTimeout(dataSource) + dataSource, err = sqlUtils.ResetReadTimeout(dataSource) if err != nil { return errors.Wrap(err, "failed to reset read timeout from datasource") } } - handle, err := SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts) + handle, err := sqlUtils.SetupConnection(ss.Logger(), "master", dataSource, ss.settings, DBPingAttempts) if err != nil { return err } @@ -338,7 +287,7 @@ func (ss *SqlStore) initConnection() error { ss.ReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceReplicas)) for i, replica := range ss.settings.DataSourceReplicas { ss.ReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{} - handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts) + handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("replica-%v", i), replica, ss.settings, DBPingAttempts) if err != nil { // Initializing to be offline ss.ReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}}) @@ -353,7 +302,7 @@ func (ss *SqlStore) initConnection() error { ss.searchReplicaXs = make([]*atomic.Pointer[sqlxDBWrapper], len(ss.settings.DataSourceSearchReplicas)) for i, replica := range ss.settings.DataSourceSearchReplicas { ss.searchReplicaXs[i] = &atomic.Pointer[sqlxDBWrapper]{} - handle, err = SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts) + handle, err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf("search-replica-%v", i), replica, ss.settings, DBPingAttempts) if err != nil { // Initializing to be offline ss.searchReplicaXs[i].Store(&sqlxDBWrapper{isOnline: &atomic.Bool{}}) @@ -370,7 +319,7 @@ func (ss *SqlStore) initConnection() error { if src.DataSource == nil { continue } - ss.replicaLagHandles[i], err = SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts) + ss.replicaLagHandles[i], err = sqlUtils.SetupConnection(ss.Logger(), fmt.Sprintf(replicaLagPrefix+"-%d", i), *src.DataSource, ss.settings, DBPingAttempts) if err != nil { mlog.Warn("Failed to setup replica lag handle. Skipping..", mlog.String("db", fmt.Sprintf(replicaLagPrefix+"-%d", i)), mlog.Err(err)) continue @@ -525,7 +474,7 @@ func (ss *SqlStore) monitorReplicas() { return } - handle, err := SetupConnection(ss.Logger(), name, dsn, ss.settings, 1) + handle, err := sqlUtils.SetupConnection(ss.Logger(), name, dsn, ss.settings, 1) if err != nil { mlog.Warn("Failed to setup connection. Skipping..", mlog.String("db", name), mlog.Err(err)) return diff --git a/server/channels/store/sqlstore/utils.go b/server/channels/store/sqlstore/utils.go index aebdc14a2d..75357f4c8a 100644 --- a/server/channels/store/sqlstore/utils.go +++ b/server/channels/store/sqlstore/utils.go @@ -5,7 +5,6 @@ package sqlstore import ( "database/sql" - "errors" "io" "net/url" "strconv" @@ -16,8 +15,6 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" - - "github.com/go-sql-driver/mysql" ) var escapeLikeSearchChar = []string{ @@ -183,57 +180,6 @@ func AppendBinaryFlag(buf []byte) []byte { return append([]byte{0x01}, buf...) } -// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work. -func AppendMultipleStatementsFlag(dataSource string) (string, error) { - config, err := mysql.ParseDSN(dataSource) - if err != nil { - return "", err - } - - if config.Params == nil { - config.Params = map[string]string{} - } - - config.Params["multiStatements"] = "true" - return config.FormatDSN(), nil -} - -// ResetReadTimeout removes the timeout constraint from the MySQL dsn. -func ResetReadTimeout(dataSource string) (string, error) { - config, err := mysql.ParseDSN(dataSource) - if err != nil { - return "", err - } - config.ReadTimeout = 0 - return config.FormatDSN(), nil -} - -func SanitizeDataSource(driverName, dataSource string) (string, error) { - switch driverName { - case model.DatabaseDriverPostgres: - u, err := url.Parse(dataSource) - if err != nil { - return "", err - } - u.User = url.UserPassword("****", "****") - params := u.Query() - params.Del("user") - params.Del("password") - u.RawQuery = params.Encode() - return u.String(), nil - case model.DatabaseDriverMysql: - cfg, err := mysql.ParseDSN(dataSource) - if err != nil { - return "", err - } - cfg.User = "****" - cfg.Passwd = "****" - return cfg.FormatDSN(), nil - default: - return "", errors.New("invalid drivername. Not postgres or mysql.") - } -} - const maxTokenSize = 50 // trimInput limits the string to a max size to prevent clogging up disk space diff --git a/server/channels/store/sqlstore/utils_test.go b/server/channels/store/sqlstore/utils_test.go index d110803841..6ec6a3b2fe 100644 --- a/server/channels/store/sqlstore/utils_test.go +++ b/server/channels/store/sqlstore/utils_test.go @@ -6,7 +6,6 @@ package sqlstore import ( "testing" - "github.com/mattermost/mattermost/server/public/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -134,72 +133,3 @@ func TestMySQLJSONArgs(t *testing.T) { assert.Equal(t, test.argString, argString) } } - -func TestAppendMultipleStatementsFlag(t *testing.T) { - testCases := []struct { - Scenario string - DSN string - ExpectedDSN string - }{ - { - "Should append multiStatements param to the DSN path with existing params", - "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s", - "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true", - }, - { - "Should append multiStatements param to the DSN path with no existing params", - "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost", - "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true", - }, - } - - for _, tc := range testCases { - t.Run(tc.Scenario, func(t *testing.T) { - res, err := AppendMultipleStatementsFlag(tc.DSN) - require.NoError(t, err) - assert.Equal(t, tc.ExpectedDSN, res) - }) - } -} - -func TestSanitizeDataSource(t *testing.T) { - t.Run(model.DatabaseDriverPostgres, func(t *testing.T) { - testCases := []struct { - Original string - Sanitized string - }{ - { - "postgres://mmuser:mostest@localhost/dummy?sslmode=disable", - "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", - }, - { - "postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest", - "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", - }, - } - driver := model.DatabaseDriverPostgres - for _, tc := range testCases { - out, err := SanitizeDataSource(driver, tc.Original) - require.NoError(t, err) - assert.Equal(t, tc.Sanitized, out) - } - }) - - t.Run(model.DatabaseDriverMysql, func(t *testing.T) { - testCases := []struct { - Original string - Sanitized string - }{ - { - "mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s", - "****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8", - }, - } - driver := model.DatabaseDriverMysql - for _, tc := range testCases { - out, err := SanitizeDataSource(driver, tc.Original) - require.NoError(t, err) - assert.Equal(t, tc.Sanitized, out) - } - }) -} diff --git a/server/config/database.go b/server/config/database.go index 8414144071..51a8e0ac97 100644 --- a/server/config/database.go +++ b/server/config/database.go @@ -15,6 +15,8 @@ import ( "path/filepath" "strings" + sqlUtils "github.com/mattermost/mattermost/server/public/utils/sql" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -27,8 +29,6 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" - "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" - "github.com/mattermost/morph/drivers" ms "github.com/mattermost/morph/drivers/mysql" ps "github.com/mattermost/morph/drivers/postgres" @@ -121,12 +121,12 @@ func (ds *DatabaseStore) initializeConfigurationsTable() error { var driver drivers.Driver switch ds.driverName { case model.DatabaseDriverMysql: - dataSource, rErr := sqlstore.ResetReadTimeout(ds.dataSourceName) + dataSource, rErr := sqlUtils.ResetReadTimeout(ds.dataSourceName) if rErr != nil { return fmt.Errorf("failed to reset read timeout from datasource: %w", rErr) } - dataSource, err = sqlstore.AppendMultipleStatementsFlag(dataSource) + dataSource, err = sqlUtils.AppendMultipleStatementsFlag(dataSource) if err != nil { return err } @@ -409,7 +409,7 @@ func (ds *DatabaseStore) RemoveFile(name string) error { func (ds *DatabaseStore) String() string { // This is called during the running of MM, so we expect the parsing of DSN // to be successful. - sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn) + sanitized, _ := sqlUtils.SanitizeDataSource(ds.driverName, ds.originalDsn) return sanitized } diff --git a/server/public/utils/sql/sql_utils.go b/server/public/utils/sql/sql_utils.go new file mode 100644 index 0000000000..a1029f314a --- /dev/null +++ b/server/public/utils/sql/sql_utils.go @@ -0,0 +1,126 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sql + +import ( + "context" + dbsql "database/sql" + "net/url" + "strings" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/pkg/errors" +) + +const ( + DBPingTimeoutSecs = 10 + + replicaLagPrefix = "replica-lag" +) + +// ResetReadTimeout removes the timeout constraint from the MySQL dsn. +func ResetReadTimeout(dataSource string) (string, error) { + config, err := mysql.ParseDSN(dataSource) + if err != nil { + return "", err + } + config.ReadTimeout = 0 + return config.FormatDSN(), nil +} + +// AppendMultipleStatementsFlag attached dsn parameters to MySQL dsn in order to make migrations work. +func AppendMultipleStatementsFlag(dataSource string) (string, error) { + config, err := mysql.ParseDSN(dataSource) + if err != nil { + return "", err + } + + if config.Params == nil { + config.Params = map[string]string{} + } + + config.Params["multiStatements"] = "true" + return config.FormatDSN(), nil +} + +// SetupConnection sets up the connection to the database and pings it to make sure it's alive. +// It also applies any database configuration settings that are required. +func SetupConnection(logger mlog.LoggerIFace, connType string, dataSource string, settings *model.SqlSettings, attempts int) (*dbsql.DB, error) { + db, err := dbsql.Open(*settings.DriverName, dataSource) + if err != nil { + return nil, errors.Wrap(err, "failed to open SQL connection") + } + + // At this point, we have passed sql.Open, so we deliberately ignore any errors. + sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource) + + logger = logger.With( + mlog.String("database", connType), + mlog.String("dataSource", sanitized), + ) + + for i := 0; i < attempts; i++ { + logger.Info("Pinging SQL") + ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second) + defer cancel() + err = db.PingContext(ctx) + if err != nil { + if i == attempts-1 { + return nil, err + } + logger.Error("Failed to ping DB", mlog.Int("retrying in seconds", DBPingTimeoutSecs), mlog.Err(err)) + time.Sleep(DBPingTimeoutSecs * time.Second) + continue + } + break + } + + if strings.HasPrefix(connType, replicaLagPrefix) { + // If this is a replica lag connection, we just open one connection. + // + // Arguably, if the query doesn't require a special credential, it does take up + // one extra connection from the replica DB. But falling back to the replica + // data source when the replica lag data source is null implies an ordering constraint + // which makes things brittle and is not a good design. + // If connections are an overhead, it is advised to use a connection pool. + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + } else { + db.SetMaxIdleConns(*settings.MaxIdleConns) + db.SetMaxOpenConns(*settings.MaxOpenConns) + } + db.SetConnMaxLifetime(time.Duration(*settings.ConnMaxLifetimeMilliseconds) * time.Millisecond) + db.SetConnMaxIdleTime(time.Duration(*settings.ConnMaxIdleTimeMilliseconds) * time.Millisecond) + + return db, nil +} + +func SanitizeDataSource(driverName, dataSource string) (string, error) { + switch driverName { + case model.DatabaseDriverPostgres: + u, err := url.Parse(dataSource) + if err != nil { + return "", err + } + u.User = url.UserPassword("****", "****") + params := u.Query() + params.Del("user") + params.Del("password") + u.RawQuery = params.Encode() + return u.String(), nil + case model.DatabaseDriverMysql: + cfg, err := mysql.ParseDSN(dataSource) + if err != nil { + return "", err + } + cfg.User = "****" + cfg.Passwd = "****" + return cfg.FormatDSN(), nil + default: + return "", errors.New("invalid drivername. Not postgres or mysql.") + } +} diff --git a/server/public/utils/sql/sql_utils_test.go b/server/public/utils/sql/sql_utils_test.go new file mode 100644 index 0000000000..30a5818220 --- /dev/null +++ b/server/public/utils/sql/sql_utils_test.go @@ -0,0 +1,109 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sql + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAppendMultipleStatementsFlag(t *testing.T) { + testCases := []struct { + Scenario string + DSN string + ExpectedDSN string + }{ + { + "Should append multiStatements param to the DSN path with existing params", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?writeTimeout=30s&multiStatements=true", + }, + { + "Should append multiStatements param to the DSN path with no existing params", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?multiStatements=true", + }, + } + + for _, tc := range testCases { + t.Run(tc.Scenario, func(t *testing.T) { + res, err := AppendMultipleStatementsFlag(tc.DSN) + require.NoError(t, err) + assert.Equal(t, tc.ExpectedDSN, res) + }) + } +} + +func TestResetReadTimeout(t *testing.T) { + testCases := []struct { + Scenario string + DSN string + ExpectedDSN string + }{ + { + "Should re move read timeout param from the DSN", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost?readTimeout=30s", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost", + }, + { + "Should change nothing as there is no read timeout param specified", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost", + "user:rand?&ompasswith@character@unix(/var/run/mysqld/mysqld.sock)/mattermost", + }, + } + + for _, tc := range testCases { + t.Run(tc.Scenario, func(t *testing.T) { + res, err := ResetReadTimeout(tc.DSN) + require.NoError(t, err) + assert.Equal(t, tc.ExpectedDSN, res) + }) + } +} + +func TestSanitizeDataSource(t *testing.T) { + t.Run(model.DatabaseDriverPostgres, func(t *testing.T) { + testCases := []struct { + Original string + Sanitized string + }{ + { + "postgres://mmuser:mostest@localhost/dummy?sslmode=disable", + "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", + }, + { + "postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest", + "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", + }, + } + driver := model.DatabaseDriverPostgres + for _, tc := range testCases { + out, err := SanitizeDataSource(driver, tc.Original) + require.NoError(t, err) + assert.Equal(t, tc.Sanitized, out) + } + }) + + t.Run(model.DatabaseDriverMysql, func(t *testing.T) { + testCases := []struct { + Original string + Sanitized string + }{ + { + "mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s", + "****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8", + }, + } + driver := model.DatabaseDriverMysql + for _, tc := range testCases { + out, err := SanitizeDataSource(driver, tc.Original) + require.NoError(t, err) + assert.Equal(t, tc.Sanitized, out) + } + }) +}