diff --git a/pkg/services/sqlstore/database_config.go b/pkg/services/sqlstore/database_config.go new file mode 100644 index 00000000000..af20267f8fa --- /dev/null +++ b/pkg/services/sqlstore/database_config.go @@ -0,0 +1,222 @@ +package sqlstore + +import ( + "errors" + "fmt" + "net/url" + "os" + "path" + "path/filepath" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/sqlstore/migrator" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" +) + +type DatabaseConfig struct { + Type string + Host string + Name string + User string + Pwd string + Path string + SslMode string + CaCertPath string + ClientKeyPath string + ClientCertPath string + ServerCertName string + ConnectionString string + IsolationLevel string + MaxOpenConn int + MaxIdleConn int + ConnMaxLifetime int + CacheMode string + WALEnabled bool + UrlQueryParams map[string][]string + SkipMigrations bool + MigrationLockAttemptTimeout int + LogQueries bool + // SQLite only + QueryRetries int + // SQLite only + TransactionRetries int +} + +func NewDatabaseConfig(cfg *setting.Cfg) (*DatabaseConfig, error) { + if cfg == nil { + return nil, errors.New("cfg cannot be nil") + } + + dbCfg := &DatabaseConfig{} + if err := dbCfg.readConfig(cfg); err != nil { + return nil, err + } + + if err := dbCfg.buildConnectionString(cfg); err != nil { + return nil, err + } + + return dbCfg, nil +} + +func (dbCfg *DatabaseConfig) readConfig(cfg *setting.Cfg) error { + sec := cfg.Raw.Section("database") + + cfgURL := sec.Key("url").String() + if len(cfgURL) != 0 { + dbURL, err := url.Parse(cfgURL) + if err != nil { + return err + } + dbCfg.Type = dbURL.Scheme + dbCfg.Host = dbURL.Host + + pathSplit := strings.Split(dbURL.Path, "/") + if len(pathSplit) > 1 { + dbCfg.Name = pathSplit[1] + } + + userInfo := dbURL.User + if userInfo != nil { + dbCfg.User = userInfo.Username() + dbCfg.Pwd, _ = userInfo.Password() + } + + dbCfg.UrlQueryParams = dbURL.Query() + } else { + dbCfg.Type = sec.Key("type").String() + dbCfg.Host = sec.Key("host").String() + dbCfg.Name = sec.Key("name").String() + dbCfg.User = sec.Key("user").String() + dbCfg.ConnectionString = sec.Key("connection_string").String() + dbCfg.Pwd = sec.Key("password").String() + } + + dbCfg.MaxOpenConn = sec.Key("max_open_conn").MustInt(0) + dbCfg.MaxIdleConn = sec.Key("max_idle_conn").MustInt(2) + dbCfg.ConnMaxLifetime = sec.Key("conn_max_lifetime").MustInt(14400) + + dbCfg.SslMode = sec.Key("ssl_mode").String() + dbCfg.CaCertPath = sec.Key("ca_cert_path").String() + dbCfg.ClientKeyPath = sec.Key("client_key_path").String() + dbCfg.ClientCertPath = sec.Key("client_cert_path").String() + dbCfg.ServerCertName = sec.Key("server_cert_name").String() + dbCfg.Path = sec.Key("path").MustString("data/grafana.db") + dbCfg.IsolationLevel = sec.Key("isolation_level").String() + + dbCfg.CacheMode = sec.Key("cache_mode").MustString("private") + dbCfg.WALEnabled = sec.Key("wal").MustBool(false) + dbCfg.SkipMigrations = sec.Key("skip_migrations").MustBool() + dbCfg.MigrationLockAttemptTimeout = sec.Key("locking_attempt_timeout_sec").MustInt() + + dbCfg.QueryRetries = sec.Key("query_retries").MustInt() + dbCfg.TransactionRetries = sec.Key("transaction_retries").MustInt(5) + + dbCfg.LogQueries = sec.Key("log_queries").MustBool(false) + + return nil +} + +func (dbCfg *DatabaseConfig) buildConnectionString(cfg *setting.Cfg) error { + if dbCfg.ConnectionString != "" { + return nil + } + + cnnstr := "" + + switch dbCfg.Type { + case migrator.MySQL: + protocol := "tcp" + if strings.HasPrefix(dbCfg.Host, "/") { + protocol = "unix" + } + + cnnstr = fmt.Sprintf("%s:%s@%s(%s)/%s?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", + dbCfg.User, dbCfg.Pwd, protocol, dbCfg.Host, dbCfg.Name) + + if dbCfg.SslMode == "true" || dbCfg.SslMode == "skip-verify" { + tlsCert, err := makeCert(dbCfg) + if err != nil { + return err + } + if err := mysql.RegisterTLSConfig("custom", tlsCert); err != nil { + return err + } + + cnnstr += "&tls=custom" + } + + if isolation := dbCfg.IsolationLevel; isolation != "" { + val := url.QueryEscape(fmt.Sprintf("'%s'", isolation)) + cnnstr += fmt.Sprintf("&transaction_isolation=%s", val) + } + + // nolint:staticcheck + if cfg.IsFeatureToggleEnabled(featuremgmt.FlagMysqlAnsiQuotes) { + cnnstr += "&sql_mode='ANSI_QUOTES'" + } + + cnnstr += buildExtraConnectionString('&', dbCfg.UrlQueryParams) + case migrator.Postgres: + addr, err := util.SplitHostPortDefault(dbCfg.Host, "127.0.0.1", "5432") + if err != nil { + return fmt.Errorf("invalid host specifier '%s': %w", dbCfg.Host, err) + } + + args := []any{dbCfg.User, addr.Host, addr.Port, dbCfg.Name, dbCfg.SslMode, dbCfg.ClientCertPath, + dbCfg.ClientKeyPath, dbCfg.CaCertPath} + for i, arg := range args { + if arg == "" { + args[i] = "''" + } + } + cnnstr = fmt.Sprintf("user=%s host=%s port=%s dbname=%s sslmode=%s sslcert=%s sslkey=%s sslrootcert=%s", args...) + if dbCfg.Pwd != "" { + cnnstr += fmt.Sprintf(" password=%s", dbCfg.Pwd) + } + + cnnstr += buildExtraConnectionString(' ', dbCfg.UrlQueryParams) + case migrator.SQLite: + // special case for tests + if !filepath.IsAbs(dbCfg.Path) { + dbCfg.Path = filepath.Join(cfg.DataPath, dbCfg.Path) + } + if err := os.MkdirAll(path.Dir(dbCfg.Path), os.ModePerm); err != nil { + return err + } + + cnnstr = fmt.Sprintf("file:%s?cache=%s&mode=rwc", dbCfg.Path, dbCfg.CacheMode) + + if dbCfg.WALEnabled { + cnnstr += "&_journal_mode=WAL" + } + + cnnstr += buildExtraConnectionString('&', dbCfg.UrlQueryParams) + default: + return fmt.Errorf("unknown database type: %s", dbCfg.Type) + } + + dbCfg.ConnectionString = cnnstr + + return nil +} + +func buildExtraConnectionString(sep rune, urlQueryParams map[string][]string) string { + if urlQueryParams == nil { + return "" + } + + var sb strings.Builder + for key, values := range urlQueryParams { + for _, value := range values { + sb.WriteRune(sep) + sb.WriteString(key) + sb.WriteRune('=') + sb.WriteString(value) + } + } + return sb.String() +} diff --git a/pkg/services/sqlstore/database_config_test.go b/pkg/services/sqlstore/database_config_test.go new file mode 100644 index 00000000000..9fba6b55a23 --- /dev/null +++ b/pkg/services/sqlstore/database_config_test.go @@ -0,0 +1,146 @@ +package sqlstore + +import ( + "errors" + "net/url" + "testing" + + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type databaseConfigTest struct { + name string + dbType string + dbHost string + dbURL string + dbUser string + dbPwd string + expConnStr string + features featuremgmt.FeatureToggles + err error +} + +var databaseConfigTestCases = []databaseConfigTest{ + { + name: "MySQL IPv4", + dbType: "mysql", + dbHost: "1.2.3.4:5678", + expConnStr: ":@tcp(1.2.3.4:5678)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", + }, + { + name: "Postgres IPv4", + dbType: "postgres", + dbHost: "1.2.3.4:5678", + expConnStr: "user='' host=1.2.3.4 port=5678 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", + }, + { + name: "Postgres IPv4 (Default Port)", + dbType: "postgres", + dbHost: "1.2.3.4", + expConnStr: "user='' host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", + }, + { + name: "Postgres username and password", + dbType: "postgres", + dbHost: "1.2.3.4", + dbUser: "grafana", + dbPwd: "password", + expConnStr: "user=grafana host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert='' password=password", + }, + { + name: "Postgres username no password", + dbType: "postgres", + dbHost: "1.2.3.4", + dbUser: "grafana", + dbPwd: "", + expConnStr: "user=grafana host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", + }, + { + name: "MySQL IPv4 (Default Port)", + dbType: "mysql", + dbHost: "1.2.3.4", + expConnStr: ":@tcp(1.2.3.4)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", + }, + { + name: "MySQL IPv6", + dbType: "mysql", + dbHost: "[fe80::24e8:31b2:91df:b177]:1234", + expConnStr: ":@tcp([fe80::24e8:31b2:91df:b177]:1234)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", + }, + { + name: "Postgres IPv6", + dbType: "postgres", + dbHost: "[fe80::24e8:31b2:91df:b177]:1234", + expConnStr: "user='' host=fe80::24e8:31b2:91df:b177 port=1234 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", + }, + { + name: "MySQL IPv6 (Default Port)", + dbType: "mysql", + dbHost: "[::1]", + expConnStr: ":@tcp([::1])/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", + }, + { + name: "Postgres IPv6 (Default Port)", + dbType: "postgres", + dbHost: "[::1]", + expConnStr: "user='' host=::1 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", + }, + { + name: "Invalid database URL", + dbURL: "://invalid.com/", + err: &url.Error{Op: "parse", URL: "://invalid.com/", Err: errors.New("missing protocol scheme")}, + }, + { + name: "MySQL with ANSI_QUOTES mode", + dbType: "mysql", + dbHost: "[::1]", + features: featuremgmt.WithFeatures(featuremgmt.FlagMysqlAnsiQuotes), + expConnStr: ":@tcp([::1])/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true&sql_mode='ANSI_QUOTES'", + }, +} + +func TestIntegrationSQLConnectionString(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + for _, testCase := range databaseConfigTestCases { + t.Run(testCase.name, func(t *testing.T) { + cfg := makeDatabaseTestConfig(t, testCase) + dbCfg, err := NewDatabaseConfig(cfg) + require.Equal(t, testCase.err, err) + if testCase.expConnStr != "" { + assert.Equal(t, testCase.expConnStr, dbCfg.ConnectionString) + } + }) + } +} + +func makeDatabaseTestConfig(t *testing.T, tc databaseConfigTest) *setting.Cfg { + t.Helper() + + if tc.features == nil { + tc.features = featuremgmt.WithFeatures() + } + // nolint:staticcheck + cfg := setting.NewCfgWithFeatures(tc.features.IsEnabledGlobally) + + sec, err := cfg.Raw.NewSection("database") + require.NoError(t, err) + _, err = sec.NewKey("type", tc.dbType) + require.NoError(t, err) + _, err = sec.NewKey("host", tc.dbHost) + require.NoError(t, err) + _, err = sec.NewKey("url", tc.dbURL) + require.NoError(t, err) + _, err = sec.NewKey("user", tc.dbUser) + require.NoError(t, err) + _, err = sec.NewKey("name", "test_db") + require.NoError(t, err) + _, err = sec.NewKey("password", tc.dbPwd) + require.NoError(t, err) + + return cfg +} diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index 04356f23ec4..faf272ce9f3 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -4,10 +4,7 @@ import ( "context" "errors" "fmt" - "net/url" "os" - "path" - "path/filepath" "strings" "sync" "time" @@ -34,7 +31,6 @@ import ( "github.com/grafana/grafana/pkg/services/stats" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" ) // ContextSessionKey is used as key to save values in `context.Context` @@ -45,7 +41,7 @@ type SQLStore struct { sqlxsession *session.SessionDB bus bus.Bus - dbCfg DatabaseConfig + dbCfg *DatabaseConfig engine *xorm.Engine log log.Logger Dialect migrator.Dialect @@ -160,18 +156,6 @@ func (ss *SQLStore) Reset() error { return ss.ensureMainOrgAndAdminUser(false) } -// TestReset resets database state. If default org and user creation is enabled, -// it will be ensured they exist in the database. TestReset() is more permissive -// than Reset in that it will create the user and org whether or not there are -// already users in the database. -func (ss *SQLStore) TestReset() error { - if ss.skipEnsureDefaultOrgAndUser { - return nil - } - - return ss.ensureMainOrgAndAdminUser(true) -} - // Quote quotes the value in the used SQL dialect func (ss *SQLStore) Quote(value string) string { return ss.engine.Quote(value) @@ -248,110 +232,6 @@ func (ss *SQLStore) ensureMainOrgAndAdminUser(test bool) error { return err } -func (ss *SQLStore) buildExtraConnectionString(sep rune) string { - if ss.dbCfg.UrlQueryParams == nil { - return "" - } - - var sb strings.Builder - for key, values := range ss.dbCfg.UrlQueryParams { - for _, value := range values { - sb.WriteRune(sep) - sb.WriteString(key) - sb.WriteRune('=') - sb.WriteString(value) - } - } - return sb.String() -} - -func (ss *SQLStore) buildConnectionString() (string, error) { - if err := ss.readConfig(); err != nil { - return "", err - } - - cnnstr := ss.dbCfg.ConnectionString - - // special case used by integration tests - if cnnstr != "" { - return cnnstr, nil - } - - switch ss.dbCfg.Type { - case migrator.MySQL: - protocol := "tcp" - if strings.HasPrefix(ss.dbCfg.Host, "/") { - protocol = "unix" - } - - cnnstr = fmt.Sprintf("%s:%s@%s(%s)/%s?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", - ss.dbCfg.User, ss.dbCfg.Pwd, protocol, ss.dbCfg.Host, ss.dbCfg.Name) - - if ss.dbCfg.SslMode == "true" || ss.dbCfg.SslMode == "skip-verify" { - tlsCert, err := makeCert(ss.dbCfg) - if err != nil { - return "", err - } - if err := mysql.RegisterTLSConfig("custom", tlsCert); err != nil { - return "", err - } - - cnnstr += "&tls=custom" - } - - if isolation := ss.dbCfg.IsolationLevel; isolation != "" { - val := url.QueryEscape(fmt.Sprintf("'%s'", isolation)) - cnnstr += fmt.Sprintf("&transaction_isolation=%s", val) - } - - // nolint:staticcheck - if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagMysqlAnsiQuotes) { - cnnstr += "&sql_mode='ANSI_QUOTES'" - } - - cnnstr += ss.buildExtraConnectionString('&') - case migrator.Postgres: - addr, err := util.SplitHostPortDefault(ss.dbCfg.Host, "127.0.0.1", "5432") - if err != nil { - return "", fmt.Errorf("invalid host specifier '%s': %w", ss.dbCfg.Host, err) - } - - args := []any{ss.dbCfg.User, addr.Host, addr.Port, ss.dbCfg.Name, ss.dbCfg.SslMode, ss.dbCfg.ClientCertPath, - ss.dbCfg.ClientKeyPath, ss.dbCfg.CaCertPath} - for i, arg := range args { - if arg == "" { - args[i] = "''" - } - } - cnnstr = fmt.Sprintf("user=%s host=%s port=%s dbname=%s sslmode=%s sslcert=%s sslkey=%s sslrootcert=%s", args...) - if ss.dbCfg.Pwd != "" { - cnnstr += fmt.Sprintf(" password=%s", ss.dbCfg.Pwd) - } - - cnnstr += ss.buildExtraConnectionString(' ') - case migrator.SQLite: - // special case for tests - if !filepath.IsAbs(ss.dbCfg.Path) { - ss.dbCfg.Path = filepath.Join(ss.Cfg.DataPath, ss.dbCfg.Path) - } - if err := os.MkdirAll(path.Dir(ss.dbCfg.Path), os.ModePerm); err != nil { - return "", err - } - - cnnstr = fmt.Sprintf("file:%s?cache=%s&mode=rwc", ss.dbCfg.Path, ss.dbCfg.CacheMode) - - if ss.dbCfg.WALEnabled { - cnnstr += "&_journal_mode=WAL" - } - - cnnstr += ss.buildExtraConnectionString('&') - default: - return "", fmt.Errorf("unknown database type: %s", ss.dbCfg.Type) - } - - return cnnstr, nil -} - // initEngine initializes ss.engine. func (ss *SQLStore) initEngine(engine *xorm.Engine) error { if ss.engine != nil { @@ -359,18 +239,20 @@ func (ss *SQLStore) initEngine(engine *xorm.Engine) error { return nil } - connectionString, err := ss.buildConnectionString() + dbCfg, err := NewDatabaseConfig(ss.Cfg) if err != nil { return err } + ss.dbCfg = dbCfg + if ss.Cfg.DatabaseInstrumentQueries { ss.dbCfg.Type = WrapDatabaseDriverWithHooks(ss.dbCfg.Type, ss.tracer) } ss.log.Info("Connecting to DB", "dbtype", ss.dbCfg.Type) - if ss.dbCfg.Type == migrator.SQLite && strings.HasPrefix(connectionString, "file:") && - !strings.HasPrefix(connectionString, "file::memory:") { + if ss.dbCfg.Type == migrator.SQLite && strings.HasPrefix(ss.dbCfg.ConnectionString, "file:") && + !strings.HasPrefix(ss.dbCfg.ConnectionString, "file::memory:") { exists, err := fs.Exists(ss.dbCfg.Path) if err != nil { return fmt.Errorf("can't check for existence of %q: %w", ss.dbCfg.Path, err) @@ -400,14 +282,14 @@ func (ss *SQLStore) initEngine(engine *xorm.Engine) error { } if engine == nil { var err error - engine, err = xorm.NewEngine(ss.dbCfg.Type, connectionString) + engine, err = xorm.NewEngine(ss.dbCfg.Type, ss.dbCfg.ConnectionString) if err != nil { return err } // Only for MySQL or MariaDB, verify we can connect with the current connection string's system var for transaction isolation. // If not, create a new engine with a compatible connection string. if ss.dbCfg.Type == migrator.MySQL { - engine, err = ss.ensureTransactionIsolationCompatibility(engine, connectionString) + engine, err = ss.ensureTransactionIsolationCompatibility(engine, ss.dbCfg.ConnectionString) if err != nil { return err } @@ -459,62 +341,6 @@ func (ss *SQLStore) ensureTransactionIsolationCompatibility(engine *xorm.Engine, return engine, nil } -// readConfig initializes the SQLStore from its configuration. -func (ss *SQLStore) readConfig() error { - sec := ss.Cfg.Raw.Section("database") - - cfgURL := sec.Key("url").String() - if len(cfgURL) != 0 { - dbURL, err := url.Parse(cfgURL) - if err != nil { - return err - } - ss.dbCfg.Type = dbURL.Scheme - ss.dbCfg.Host = dbURL.Host - - pathSplit := strings.Split(dbURL.Path, "/") - if len(pathSplit) > 1 { - ss.dbCfg.Name = pathSplit[1] - } - - userInfo := dbURL.User - if userInfo != nil { - ss.dbCfg.User = userInfo.Username() - ss.dbCfg.Pwd, _ = userInfo.Password() - } - - ss.dbCfg.UrlQueryParams = dbURL.Query() - } else { - ss.dbCfg.Type = sec.Key("type").String() - ss.dbCfg.Host = sec.Key("host").String() - ss.dbCfg.Name = sec.Key("name").String() - ss.dbCfg.User = sec.Key("user").String() - ss.dbCfg.ConnectionString = sec.Key("connection_string").String() - ss.dbCfg.Pwd = sec.Key("password").String() - } - - ss.dbCfg.MaxOpenConn = sec.Key("max_open_conn").MustInt(0) - ss.dbCfg.MaxIdleConn = sec.Key("max_idle_conn").MustInt(2) - ss.dbCfg.ConnMaxLifetime = sec.Key("conn_max_lifetime").MustInt(14400) - - ss.dbCfg.SslMode = sec.Key("ssl_mode").String() - ss.dbCfg.CaCertPath = sec.Key("ca_cert_path").String() - ss.dbCfg.ClientKeyPath = sec.Key("client_key_path").String() - ss.dbCfg.ClientCertPath = sec.Key("client_cert_path").String() - ss.dbCfg.ServerCertName = sec.Key("server_cert_name").String() - ss.dbCfg.Path = sec.Key("path").MustString("data/grafana.db") - ss.dbCfg.IsolationLevel = sec.Key("isolation_level").String() - - ss.dbCfg.CacheMode = sec.Key("cache_mode").MustString("private") - ss.dbCfg.WALEnabled = sec.Key("wal").MustBool(false) - ss.dbCfg.SkipMigrations = sec.Key("skip_migrations").MustBool() - ss.dbCfg.MigrationLockAttemptTimeout = sec.Key("locking_attempt_timeout_sec").MustInt() - - ss.dbCfg.QueryRetries = sec.Key("query_retries").MustInt() - ss.dbCfg.TransactionRetries = sec.Key("transaction_retries").MustInt(5) - return nil -} - func (ss *SQLStore) GetMigrationLockAttemptTimeout() int { return ss.dbCfg.MigrationLockAttemptTimeout } @@ -747,55 +573,3 @@ func initTestDB(testCfg *setting.Cfg, migration registry.DatabaseMigrator, opts return testSQLStore, nil } - -func IsTestDbMySQL() bool { - if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { - return db == migrator.MySQL - } - - return false -} - -func IsTestDbPostgres() bool { - if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { - return db == migrator.Postgres - } - - return false -} - -func IsTestDBMSSQL() bool { - if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { - return db == migrator.MSSQL - } - - return false -} - -type DatabaseConfig struct { - Type string - Host string - Name string - User string - Pwd string - Path string - SslMode string - CaCertPath string - ClientKeyPath string - ClientCertPath string - ServerCertName string - ConnectionString string - IsolationLevel string - MaxOpenConn int - MaxIdleConn int - ConnMaxLifetime int - CacheMode string - WALEnabled bool - UrlQueryParams map[string][]string - SkipMigrations bool - MigrationLockAttemptTimeout int - // SQLite only - QueryRetries int - // SQLite only - TransactionRetries int -} diff --git a/pkg/services/sqlstore/sqlstore_test.go b/pkg/services/sqlstore/sqlstore_test.go index 83154c62eef..0441a0ff368 100644 --- a/pkg/services/sqlstore/sqlstore_test.go +++ b/pkg/services/sqlstore/sqlstore_test.go @@ -2,126 +2,15 @@ package sqlstore import ( "context" - "errors" - "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/setting" ) -type sqlStoreTest struct { - name string - dbType string - dbHost string - dbURL string - dbUser string - dbPwd string - expConnStr string - features featuremgmt.FeatureToggles - err error -} - -var sqlStoreTestCases = []sqlStoreTest{ - { - name: "MySQL IPv4", - dbType: "mysql", - dbHost: "1.2.3.4:5678", - expConnStr: ":@tcp(1.2.3.4:5678)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", - }, - { - name: "Postgres IPv4", - dbType: "postgres", - dbHost: "1.2.3.4:5678", - expConnStr: "user='' host=1.2.3.4 port=5678 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", - }, - { - name: "Postgres IPv4 (Default Port)", - dbType: "postgres", - dbHost: "1.2.3.4", - expConnStr: "user='' host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", - }, - { - name: "Postgres username and password", - dbType: "postgres", - dbHost: "1.2.3.4", - dbUser: "grafana", - dbPwd: "password", - expConnStr: "user=grafana host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert='' password=password", - }, - { - name: "Postgres username no password", - dbType: "postgres", - dbHost: "1.2.3.4", - dbUser: "grafana", - dbPwd: "", - expConnStr: "user=grafana host=1.2.3.4 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", - }, - { - name: "MySQL IPv4 (Default Port)", - dbType: "mysql", - dbHost: "1.2.3.4", - expConnStr: ":@tcp(1.2.3.4)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", - }, - { - name: "MySQL IPv6", - dbType: "mysql", - dbHost: "[fe80::24e8:31b2:91df:b177]:1234", - expConnStr: ":@tcp([fe80::24e8:31b2:91df:b177]:1234)/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", - }, - { - name: "Postgres IPv6", - dbType: "postgres", - dbHost: "[fe80::24e8:31b2:91df:b177]:1234", - expConnStr: "user='' host=fe80::24e8:31b2:91df:b177 port=1234 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", - }, - { - name: "MySQL IPv6 (Default Port)", - dbType: "mysql", - dbHost: "[::1]", - expConnStr: ":@tcp([::1])/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", - }, - { - name: "Postgres IPv6 (Default Port)", - dbType: "postgres", - dbHost: "[::1]", - expConnStr: "user='' host=::1 port=5432 dbname=test_db sslmode='' sslcert='' sslkey='' sslrootcert=''", - }, - { - name: "Invalid database URL", - dbURL: "://invalid.com/", - err: &url.Error{Op: "parse", URL: "://invalid.com/", Err: errors.New("missing protocol scheme")}, - }, - { - name: "MySQL with ANSI_QUOTES mode", - dbType: "mysql", - dbHost: "[::1]", - features: featuremgmt.WithFeatures(featuremgmt.FlagMysqlAnsiQuotes), - expConnStr: ":@tcp([::1])/test_db?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true&sql_mode='ANSI_QUOTES'", - }, -} - -func TestIntegrationSQLConnectionString(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test") - } - for _, testCase := range sqlStoreTestCases { - t.Run(testCase.name, func(t *testing.T) { - sqlstore := &SQLStore{} - sqlstore.Cfg = makeSQLStoreTestConfig(t, testCase) - connStr, err := sqlstore.buildConnectionString() - require.Equal(t, testCase.err, err) - - assert.Equal(t, testCase.expConnStr, connStr) - }) - } -} - func TestIntegrationIsUniqueConstraintViolation(t *testing.T) { store := InitTestDB(t) @@ -169,30 +58,3 @@ func TestIntegrationIsUniqueConstraintViolation(t *testing.T) { }) } } - -func makeSQLStoreTestConfig(t *testing.T, tc sqlStoreTest) *setting.Cfg { - t.Helper() - - if tc.features == nil { - tc.features = featuremgmt.WithFeatures() - } - // nolint:staticcheck - cfg := setting.NewCfgWithFeatures(tc.features.IsEnabledGlobally) - - sec, err := cfg.Raw.NewSection("database") - require.NoError(t, err) - _, err = sec.NewKey("type", tc.dbType) - require.NoError(t, err) - _, err = sec.NewKey("host", tc.dbHost) - require.NoError(t, err) - _, err = sec.NewKey("url", tc.dbURL) - require.NoError(t, err) - _, err = sec.NewKey("user", tc.dbUser) - require.NoError(t, err) - _, err = sec.NewKey("name", "test_db") - require.NoError(t, err) - _, err = sec.NewKey("password", tc.dbPwd) - require.NoError(t, err) - - return cfg -} diff --git a/pkg/services/sqlstore/tls_mysql.go b/pkg/services/sqlstore/tls_mysql.go index bb7baf2f4e6..d54b065da1e 100644 --- a/pkg/services/sqlstore/tls_mysql.go +++ b/pkg/services/sqlstore/tls_mysql.go @@ -11,7 +11,7 @@ import ( var tlslog = log.New("tls_mysql") -func makeCert(config DatabaseConfig) (*tls.Config, error) { +func makeCert(config *DatabaseConfig) (*tls.Config, error) { rootCertPool := x509.NewCertPool() pem, err := os.ReadFile(config.CaCertPath) if err != nil {