diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index 032554daebe..bd72e48158b 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -23,12 +23,13 @@ import ( _ "github.com/mattn/go-sqlite3" ) -type MySQLConfig struct { - SslMode string - CaCertPath string - ClientKeyPath string - ClientCertPath string - ServerCertName string + +type DatabaseConfig struct { + Type, Host, Name, User, Pwd, Path, SslMode string + CaCertPath string + ClientKeyPath string + ClientCertPath string + ServerCertName string } var ( @@ -37,11 +38,8 @@ var ( HasEngine bool - DbCfg struct { - Type, Host, Name, User, Pwd, Path, SslMode string - } + DbCfg DatabaseConfig - mysqlConfig MySQLConfig UseSQLite3 bool sqlog log.Logger = log.New("sqlstore") ) @@ -118,8 +116,8 @@ func getEngine() (*xorm.Engine, error) { cnnstr = fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8", DbCfg.User, DbCfg.Pwd, protocol, DbCfg.Host, DbCfg.Name) - if mysqlConfig.SslMode == "true" || mysqlConfig.SslMode == "skip-verify" { - tlsCert, err := makeCert("custom", mysqlConfig) + if DbCfg.SslMode == "true" || DbCfg.SslMode == "skip-verify" { + tlsCert, err := makeCert("custom", DbCfg) if err != nil { return nil, err } @@ -141,7 +139,7 @@ func getEngine() (*xorm.Engine, error) { if DbCfg.User == "" { DbCfg.User = "''" } - cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode) + cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s sslcert=%s sslkey=%s sslrootcert=%s", DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode, DbCfg.ClientCertPath, DbCfg.ClientKeyPath, DbCfg.CaCertPath) case "sqlite3": if !filepath.IsAbs(DbCfg.Path) { DbCfg.Path = filepath.Join(setting.DataPath, DbCfg.Path) @@ -189,13 +187,9 @@ func LoadConfig() { UseSQLite3 = true } DbCfg.SslMode = sec.Key("ssl_mode").String() + DbCfg.CaCertPath = sec.Key("ca_cert_path").String() + DbCfg.ClientKeyPath = sec.Key("client_key_path").String() + DbCfg.ClientCertPath = sec.Key("client_cert_path").String() + DbCfg.ServerCertName = sec.Key("server_cert_name").String() DbCfg.Path = sec.Key("path").MustString("data/grafana.db") - - if DbCfg.Type == "mysql" { - mysqlConfig.SslMode = DbCfg.SslMode - mysqlConfig.CaCertPath = sec.Key("ca_cert_path").String() - mysqlConfig.ClientKeyPath = sec.Key("client_key_path").String() - mysqlConfig.ClientCertPath = sec.Key("client_cert_path").String() - mysqlConfig.ServerCertName = sec.Key("server_cert_name").String() - } } diff --git a/pkg/services/sqlstore/tls_mysql.go b/pkg/services/sqlstore/tls_mysql.go index fb55eb401c9..3c9475e19a0 100644 --- a/pkg/services/sqlstore/tls_mysql.go +++ b/pkg/services/sqlstore/tls_mysql.go @@ -7,7 +7,7 @@ import ( "io/ioutil" ) -func makeCert(tlsPoolName string, config MySQLConfig) (*tls.Config, error) { +func makeCert(tlsPoolName string, config DatabaseConfig) (*tls.Config, error) { rootCertPool := x509.NewCertPool() pem, err := ioutil.ReadFile(config.CaCertPath) if err != nil {