fix(unified-storage): proper setup TLS in new db_engine for MySQL (#100686)

This commit is contained in:
Jean-Philippe Quéméner 2025-02-14 16:23:25 +01:00 committed by GitHub
parent 4d7b9a3c77
commit c522a5b13b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 353 additions and 131 deletions

View File

@ -1,118 +0,0 @@
package dbimpl
import (
"testing"
"github.com/stretchr/testify/assert"
)
func newValidMySQLGetter(withKeyPrefix bool) confGetter {
var prefix string
if withKeyPrefix {
prefix = "db_"
}
return newTestConfGetter(map[string]string{
prefix + "type": dbTypeMySQL,
prefix + "host": "/var/run/mysql.socket",
prefix + "name": "grafana",
prefix + "user": "user",
prefix + "password": "password",
}, prefix)
}
func TestGetEngineMySQLFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path - with key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEngineMySQL(newValidMySQLGetter(true))
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("happy path - without key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEngineMySQL(newValidMySQLGetter(false))
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypeMySQL,
"db_host": "/var/run/mysql.socket",
"db_name": string(invalidUTF8ByteSequence),
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEngineMySQL(getter)
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, errInvalidUTF8Sequence)
})
}
func newValidPostgresGetter(withKeyPrefix bool) confGetter {
var prefix string
if withKeyPrefix {
prefix = "db_"
}
return newTestConfGetter(map[string]string{
prefix + "type": dbTypePostgres,
prefix + "host": "localhost",
prefix + "name": "grafana",
prefix + "user": "user",
prefix + "password": "password",
}, prefix)
}
func TestGetEnginePostgresFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path - with key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEnginePostgres(newValidPostgresGetter(true))
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("happy path - without key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEnginePostgres(newValidPostgresGetter(false))
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypePostgres,
"db_host": string(invalidUTF8ByteSequence),
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEnginePostgres(getter)
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, errInvalidUTF8Sequence)
})
t.Run("invalid hostport", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypePostgres,
"db_host": "1:1:1",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEnginePostgres(getter)
assert.Nil(t, engine)
assert.Error(t, err)
})
}

View File

@ -7,11 +7,17 @@ import (
"time"
"github.com/go-sql-driver/mysql"
"github.com/grafana/dskit/crypto/tls"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
// tlsConfigName is the name of the TLS config that we register with the MySQL
// driver.
const tlsConfigName = "db_engine_tls"
func getEngineMySQL(getter confGetter) (*xorm.Engine, error) {
config := mysql.NewConfig()
config.User = getter.String("user")
@ -25,29 +31,22 @@ func getEngineMySQL(getter confGetter) (*xorm.Engine, error) {
// See: https://dev.mysql.com/doc/refman/en/sql-mode.html
"@@SESSION.sql_mode": "ANSI",
}
sslMode := getter.String("ssl_mode")
if sslMode == "true" || sslMode == "skip-verify" {
config.Params["tls"] = "preferred"
}
tls := getter.String("tls")
if tls != "" {
config.Params["tls"] = tls
}
config.Collation = "utf8mb4_unicode_ci"
config.Loc = time.UTC
config.AllowNativePasswords = true
config.ClientFoundRows = true
config.ParseTime = true
// Setup TLS for the database connection if configured.
if err := configureTLS(getter, config); err != nil {
return nil, fmt.Errorf("failed to configure TLS: %w", err)
}
// allow executing multiple SQL statements in a single roundtrip, and also
// enable executing the CALL statement to run stored procedures that execute
// multiple SQL statements.
//config.MultiStatements = true
// TODO: do we want to support these?
// config.ServerPubKey = getter.String("server_pub_key")
// config.TLSConfig = getter.String("tls_config_name")
if err := getter.Err(); err != nil {
return nil, fmt.Errorf("config error: %w", err)
}
@ -56,7 +55,6 @@ func getEngineMySQL(getter confGetter) (*xorm.Engine, error) {
config.Net = "unix"
}
// FIXME: get rid of xorm
engine, err := xorm.NewEngine(db.DriverMySQL, config.FormatDSN())
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
@ -69,6 +67,47 @@ func getEngineMySQL(getter confGetter) (*xorm.Engine, error) {
return engine, nil
}
func configureTLS(getter confGetter, config *mysql.Config) error {
sslMode := getter.String("ssl_mode")
if sslMode == "true" || sslMode == "skip-verify" {
tlsCfg := tls.ClientConfig{
CAPath: getter.String("ca_cert_path"),
CertPath: getter.String("client_cert_path"),
KeyPath: getter.String("client_key_path"),
ServerName: getter.String("server_cert_name"),
}
rawTLSCfg, err := tlsCfg.GetTLSConfig()
if err != nil {
return fmt.Errorf("failed to get TLS config for mysql: %w", err)
}
if sslMode == "skip-verify" {
rawTLSCfg.InsecureSkipVerify = true
}
if err := mysql.RegisterTLSConfig(tlsConfigName, rawTLSCfg); err != nil {
return fmt.Errorf("failed to register TLS config for mysql: %w", err)
}
config.TLSConfig = tlsConfigName
}
// If the TLS mode is set in the database config, we need to set it here.
if tls := getter.String("tls"); tls != "" {
// If the user has provided TLS certs, we don't want to use the tls=<value>, as
// they would override the TLS config that we set above. They both use the same
// parameter, so we need to check for that.
if sslMode == "true" {
return fmt.Errorf("cannot provide tls certs and tls=<value> at the same time")
}
config.Params["tls"] = tls
}
return nil
}
func getEnginePostgres(getter confGetter) (*xorm.Engine, error) {
dsnKV := map[string]string{
"user": getter.String("user"),

View File

@ -0,0 +1,301 @@
package dbimpl
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func newValidMySQLGetter(withKeyPrefix bool) confGetter {
var prefix string
if withKeyPrefix {
prefix = "db_"
}
return newTestConfGetter(map[string]string{
prefix + "type": dbTypeMySQL,
prefix + "host": "/var/run/mysql.socket",
prefix + "name": "grafana",
prefix + "user": "user",
prefix + "password": "password",
}, prefix)
}
func TestGetEngineMySQLFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path - with key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEngineMySQL(newValidMySQLGetter(true))
require.NotNil(t, engine)
require.NoError(t, err)
})
t.Run("happy path - without key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEngineMySQL(newValidMySQLGetter(false))
require.NotNil(t, engine)
require.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypeMySQL,
"db_host": "/var/run/mysql.socket",
"db_name": string(invalidUTF8ByteSequence),
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEngineMySQL(getter)
require.Nil(t, engine)
require.Error(t, err)
require.ErrorIs(t, err, errInvalidUTF8Sequence)
})
}
func newValidPostgresGetter(withKeyPrefix bool) confGetter {
var prefix string
if withKeyPrefix {
prefix = "db_"
}
return newTestConfGetter(map[string]string{
prefix + "type": dbTypePostgres,
prefix + "host": "localhost",
prefix + "name": "grafana",
prefix + "user": "user",
prefix + "password": "password",
}, prefix)
}
func TestGetEnginePostgresFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path - with key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEnginePostgres(newValidPostgresGetter(true))
require.NotNil(t, engine)
require.NoError(t, err)
})
t.Run("happy path - without key prefix", func(t *testing.T) {
t.Parallel()
engine, err := getEnginePostgres(newValidPostgresGetter(false))
require.NotNil(t, engine)
require.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypePostgres,
"db_host": string(invalidUTF8ByteSequence),
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEnginePostgres(getter)
require.Nil(t, engine)
require.Error(t, err)
})
t.Run("invalid hostport", func(t *testing.T) {
t.Parallel()
getter := newTestConfGetter(map[string]string{
"db_type": dbTypePostgres,
"db_host": "1:1:1",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
}, "db_")
engine, err := getEnginePostgres(getter)
require.Nil(t, engine)
require.Error(t, err)
})
}
func TestGetEngineMySQLTLS(t *testing.T) {
certs := generateTestCerts(t)
tests := []struct {
name string
config map[string]string
shouldErr bool
}{
{
name: "with TLS disabled",
config: map[string]string{
"type": "mysql",
"user": "user",
"pass": "pass",
"host": "localhost",
"name": "dbname",
"ssl_mode": "disable",
},
},
{
name: "with TLS skip-verify",
config: map[string]string{
"type": "mysql",
"user": "user",
"pass": "pass",
"host": "localhost",
"name": "dbname",
"ssl_mode": "skip-verify",
},
},
{
name: "with valid TLS certificates",
config: map[string]string{
"type": "mysql",
"user": "user",
"pass": "pass",
"host": "localhost",
"name": "dbname",
"ssl_mode": "true",
"ca_cert_path": certs.caFile,
"client_cert_path": certs.certFile,
"client_key_path": certs.keyFile,
"server_cert_name": "mysql.example.com",
},
},
{
name: "with invalid cert paths",
config: map[string]string{
"type": "mysql",
"user": "user",
"pass": "pass",
"host": "localhost",
"name": "dbname",
"ssl_mode": "true",
"ca_cert_path": "nonexistent/ca.pem",
"client_cert_path": "nonexistent/client-cert.pem",
"client_key_path": "nonexistent/client-key.pem",
"server_cert_name": "mysql.example.com",
},
shouldErr: true,
},
{
name: "with TLS certs and tls parameter",
config: map[string]string{
"type": "mysql",
"user": "user",
"pass": "pass",
"host": "localhost",
"name": "dbname",
"ssl_mode": "true",
"ca_cert_path": certs.caFile,
"client_cert_path": certs.certFile,
"client_key_path": certs.keyFile,
"server_cert_name": "mysql.example.com",
"tls": "preferred",
},
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
getter := newTestConfGetter(tt.config, "")
engine, err := getEngineMySQL(getter)
if tt.shouldErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, engine)
})
}
}
type testCerts struct {
caFile string
certFile string
keyFile string
}
func generateTestCerts(t *testing.T) testCerts {
t.Helper()
tempDir := t.TempDir()
// Generate CA private key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// Generate CA certificate
ca := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
require.NoError(t, err)
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
client := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
CommonName: "Test Client",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
SubjectKeyId: []byte{1, 2, 3, 4, 5},
}
clientBytes, err := x509.CreateCertificate(rand.Reader, client, ca, &clientKey.PublicKey, caKey)
require.NoError(t, err)
// Write certificates and keys to temporary files
caFile := filepath.Join(tempDir, "ca.pem")
certFile := filepath.Join(tempDir, "cert.pem")
keyFile := filepath.Join(tempDir, "key.pem")
writePEMFile(t, caFile, "CERTIFICATE", caBytes)
writePEMFile(t, certFile, "CERTIFICATE", clientBytes)
writePEMFile(t, keyFile, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(clientKey))
return testCerts{
caFile: caFile,
certFile: certFile,
keyFile: keyFile,
}
}
func writePEMFile(t *testing.T, filename string, blockType string, bytes []byte) {
t.Helper()
//nolint:gosec
file, err := os.Create(filename)
require.NoError(t, err)
//nolint:errcheck
defer file.Close()
err = pem.Encode(file, &pem.Block{
Type: blockType,
Bytes: bytes,
})
require.NoError(t, err)
}