Unified Storage: Fixes bug with postgres connection string and adds tests (#87656)

This commit is contained in:
owensmallwood 2024-05-13 10:16:26 -06:00 committed by GitHub
parent 8c585c4a79
commit 3bf39d6d9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 127 additions and 37 deletions

View File

@ -0,0 +1,71 @@
package dbimpl
import (
"fmt"
"strings"
"time"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"xorm.io/xorm"
)
func getEngineMySQL(cfgSection *setting.DynamicSection) (*xorm.Engine, error) {
dbHost := cfgSection.Key("db_host").MustString("")
dbName := cfgSection.Key("db_name").MustString("")
dbUser := cfgSection.Key("db_user").MustString("")
dbPass := cfgSection.Key("db_pass").MustString("")
// TODO: support all mysql connection options
protocol := "tcp"
if strings.HasPrefix(dbHost, "/") {
protocol = "unix"
}
connectionString := connectionStringMySQL(dbUser, dbPass, protocol, dbHost, dbName)
engine, err := xorm.NewEngine("mysql", connectionString)
if err != nil {
return nil, err
}
engine.SetMaxOpenConns(0)
engine.SetMaxIdleConns(2)
engine.SetConnMaxLifetime(time.Second * time.Duration(14400))
return engine, nil
}
func getEnginePostgres(cfgSection *setting.DynamicSection) (*xorm.Engine, error) {
dbHost := cfgSection.Key("db_host").MustString("")
dbName := cfgSection.Key("db_name").MustString("")
dbUser := cfgSection.Key("db_user").MustString("")
dbPass := cfgSection.Key("db_pass").MustString("")
// TODO: support all postgres connection options
dbSslMode := cfgSection.Key("db_sslmode").MustString("disable")
addr, err := util.SplitHostPortDefault(dbHost, "127.0.0.1", "5432")
if err != nil {
return nil, fmt.Errorf("invalid host specifier '%s': %w", dbHost, err)
}
connectionString := connectionStringPostgres(dbUser, dbPass, addr.Host, addr.Port, dbName, dbSslMode)
engine, err := xorm.NewEngine("postgres", connectionString)
if err != nil {
return nil, err
}
return engine, nil
}
func connectionStringMySQL(user, password, protocol, host, dbName string) string {
return fmt.Sprintf("%s:%s@%s(%s)/%s?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", user, password, protocol, host, dbName)
}
func connectionStringPostgres(user, password, host, port, dbName, sslMode string) string {
return fmt.Sprintf(
"user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", // sslcert='%s' sslkey='%s' sslrootcert='%s'",
user, password, host, port, dbName, sslMode, // ss.dbCfg.ClientCertPath, ss.dbCfg.ClientKeyPath, ss.dbCfg.CaCertPath
)
}

View File

@ -0,0 +1,54 @@
package dbimpl
import (
"strings"
"testing"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetEnginePostgresFromConfig(t *testing.T) {
cfg := setting.NewCfg()
s, err := cfg.Raw.NewSection("entity_api")
require.NoError(t, err)
s.Key("db_type").SetValue("mysql")
s.Key("db_host").SetValue("localhost")
s.Key("db_name").SetValue("grafana")
s.Key("db_user").SetValue("user")
s.Key("db_password").SetValue("password")
engine, err := getEnginePostgres(cfg.SectionWithEnvOverrides("entity_api"))
assert.NotNil(t, engine)
assert.NoError(t, err)
assert.True(t, strings.Contains(engine.DataSourceName(), "dbname=grafana"))
}
func TestGetEngineMySQLFromConfig(t *testing.T) {
cfg := setting.NewCfg()
s, err := cfg.Raw.NewSection("entity_api")
require.NoError(t, err)
s.Key("db_type").SetValue("mysql")
s.Key("db_host").SetValue("localhost")
s.Key("db_name").SetValue("grafana")
s.Key("db_user").SetValue("user")
s.Key("db_password").SetValue("password")
engine, err := getEngineMySQL(cfg.SectionWithEnvOverrides("entity_api"))
assert.NotNil(t, engine)
assert.NoError(t, err)
}
func TestGetConnectionStrings(t *testing.T) {
t.Run("generate mysql connection string", func(t *testing.T) {
expected := "user:password@tcp(localhost)/grafana?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true"
assert.Equal(t, expected, connectionStringMySQL("user", "password", "tcp", "localhost", "grafana"))
})
t.Run("generate postgres connection string", func(t *testing.T) {
expected := "user=user password=password host=localhost port=5432 dbname=grafana sslmode=disable"
assert.Equal(t, expected, connectionStringPostgres("user", "password", "localhost", "5432", "grafana", "disable"))
})
}

View File

@ -2,8 +2,6 @@ package dbimpl
import (
"fmt"
"strings"
"time"
"github.com/dlmiddlecote/sqlstats"
"github.com/grafana/grafana/pkg/infra/db"
@ -13,7 +11,6 @@ import (
entitydb "github.com/grafana/grafana/pkg/services/store/entity/db"
"github.com/grafana/grafana/pkg/services/store/entity/db/migrations"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"xorm.io/xorm"
@ -56,26 +53,8 @@ func (db *EntityDB) GetEngine() (*xorm.Engine, error) {
// if explicit connection settings are provided, use them
if dbType != "" {
dbHost := cfgSection.Key("db_host").MustString("")
dbName := cfgSection.Key("db_name").MustString("")
dbUser := cfgSection.Key("db_user").MustString("")
dbPass := cfgSection.Key("db_pass").MustString("")
if dbType == "postgres" {
// TODO: support all postgres connection options
dbSslMode := cfgSection.Key("db_sslmode").MustString("disable")
addr, err := util.SplitHostPortDefault(dbHost, "127.0.0.1", "5432")
if err != nil {
return nil, fmt.Errorf("invalid host specifier '%s': %w", dbHost, err)
}
connectionString := fmt.Sprintf(
"user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", // sslcert='%s' sslkey='%s' sslrootcert='%s'",
dbUser, dbPass, addr.Host, addr.Port, dbName, dbSslMode, // ss.dbCfg.ClientCertPath, ss.dbCfg.ClientKeyPath, ss.dbCfg.CaCertPath
)
engine, err = xorm.NewEngine("postgres", connectionString)
engine, err = getEnginePostgres(cfgSection)
if err != nil {
return nil, err
}
@ -87,24 +66,10 @@ func (db *EntityDB) GetEngine() (*xorm.Engine, error) {
// FIXME: return nil, err
}
} else if dbType == "mysql" {
// TODO: support all mysql connection options
protocol := "tcp"
if strings.HasPrefix(dbHost, "/") {
protocol = "unix"
}
connectionString := fmt.Sprintf("%s:%s@%s(%s)/%s?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true",
dbUser, dbPass, protocol, dbHost, dbName)
engine, err = xorm.NewEngine("mysql", connectionString)
engine, err = getEngineMySQL(cfgSection)
if err != nil {
return nil, err
}
engine.SetMaxOpenConns(0)
engine.SetMaxIdleConns(2)
engine.SetConnMaxLifetime(time.Second * time.Duration(14400))
_, err = engine.Exec("SELECT 1")
if err != nil {
return nil, err