Parse database host correctly when using IPv6

This commit is contained in:
Vlad Ellis 2019-01-25 21:56:19 +00:00
parent c7b556c0e4
commit c208186f26
5 changed files with 182 additions and 15 deletions

View File

@ -21,6 +21,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/sqlutil"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/go-sql-driver/mysql"
"github.com/go-xorm/xorm"
@ -222,13 +223,9 @@ func (ss *SqlStore) buildConnectionString() (string, error) {
cnnstr += "&tls=custom"
}
case migrator.POSTGRES:
var host, port = "127.0.0.1", "5432"
fields := strings.Split(ss.dbCfg.Host, ":")
if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 {
host = fields[0]
}
if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
port = fields[1]
host, port, err := util.SplitIpPort(ss.dbCfg.Host, "5432")
if err != nil {
return "", err
}
if ss.dbCfg.Pwd == "" {
ss.dbCfg.Pwd = "''"

View File

@ -0,0 +1,101 @@
package sqlstore
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
"github.com/grafana/grafana/pkg/setting"
)
type sqlStoreTest struct {
name string
dbType string
dbHost string
connStrValues []string
}
var sqlStoreTestCases = []sqlStoreTest {
sqlStoreTest {
name: "MySQL IPv4",
dbType: "mysql",
dbHost: "1.2.3.4:5678",
connStrValues: []string {"tcp(1.2.3.4:5678)"},
},
sqlStoreTest {
name: "Postgres IPv4",
dbType: "postgres",
dbHost: "1.2.3.4:5678",
connStrValues: []string {"host=1.2.3.4", "port=5678"},
},
sqlStoreTest {
name: "Postgres IPv4 (Default Port)",
dbType: "postgres",
dbHost: "1.2.3.4",
connStrValues: []string {"host=1.2.3.4", "port=5432"},
},
sqlStoreTest {
name: "MySQL IPv4 (Default Port)",
dbType: "mysql",
dbHost: "1.2.3.4",
connStrValues: []string {"tcp(1.2.3.4)"},
},
sqlStoreTest {
name: "MySQL IPv6",
dbType: "mysql",
dbHost: "[fe80::24e8:31b2:91df:b177]:1234",
connStrValues: []string {"tcp([fe80::24e8:31b2:91df:b177]:1234)"},
},
sqlStoreTest {
name: "Postgres IPv6",
dbType: "postgres",
dbHost: "[fe80::24e8:31b2:91df:b177]:1234",
connStrValues: []string {"host=fe80::24e8:31b2:91df:b177", "port=1234"},
},
sqlStoreTest {
name: "MySQL IPv6 (Default Port)",
dbType: "mysql",
dbHost: "::1",
connStrValues: []string {"tcp(::1)"},
},
sqlStoreTest {
name: "Postgres IPv6 (Default Port)",
dbType: "postgres",
dbHost: "::1",
connStrValues: []string {"host=::1", "port=5432"},
},
}
func TestSqlConnectionString(t *testing.T) {
Convey("Testing SQL Connection Strings", t, func() {
t.Helper()
for _, testCase := range sqlStoreTestCases {
Convey(testCase.name, func() {
sqlstore := &SqlStore{}
sqlstore.Cfg = makeSqlStoreTestConfig(testCase.dbType, testCase.dbHost)
sqlstore.readConfig()
connStr, err := sqlstore.buildConnectionString()
So(err, ShouldBeNil)
for _, connSubStr := range testCase.connStrValues {
So(connStr, ShouldContainSubstring, connSubStr)
}
})
}
})
}
func makeSqlStoreTestConfig(dbType string, host string) *setting.Cfg {
cfg := setting.NewCfg()
sec, _ := cfg.Raw.NewSection("database")
sec.NewKey("type", dbType)
sec.NewKey("host", host)
sec.NewKey("user", "user")
sec.NewKey("name", "test_db")
sec.NewKey("password", "pass")
return cfg;
}

View File

@ -4,13 +4,13 @@ import (
"database/sql"
"fmt"
"strconv"
"strings"
_ "github.com/denisenkom/go-mssqldb"
"github.com/go-xorm/core"
"github.com/grafana/grafana/pkg/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/tsdb"
"github.com/grafana/grafana/pkg/util"
)
func init() {
@ -20,7 +20,10 @@ func init() {
func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) {
logger := log.New("tsdb.mssql")
cnnstr := generateConnectionString(datasource)
cnnstr, err := generateConnectionString(datasource)
if err != nil {
return nil, err
}
logger.Debug("getEngine", "connection", cnnstr)
config := tsdb.SqlQueryEndpointConfiguration{
@ -37,7 +40,7 @@ func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin
return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger)
}
func generateConnectionString(datasource *models.DataSource) string {
func generateConnectionString(datasource *models.DataSource) (string, error) {
password := ""
for key, value := range datasource.SecureJsonData.Decrypt() {
if key == "password" {
@ -46,12 +49,11 @@ func generateConnectionString(datasource *models.DataSource) string {
}
}
hostParts := strings.Split(datasource.Url, ":")
if len(hostParts) < 2 {
hostParts = append(hostParts, "1433")
server, port, err := util.SplitIpPort(datasource.Url, "1433")
if err != nil {
return "", err
}
server, port := hostParts[0], hostParts[1]
encrypt := datasource.JsonData.Get("encrypt").MustString("false")
connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;",
server,
@ -63,7 +65,7 @@ func generateConnectionString(datasource *models.DataSource) string {
if encrypt != "false" {
connStr += fmt.Sprintf("encrypt=%s;", encrypt)
}
return connStr
return connStr, nil
}
type mssqlRowTransformer struct {

24
pkg/util/ip.go Normal file
View File

@ -0,0 +1,24 @@
package util
import (
"net"
)
func SplitIpPort(ipStr string, portDefault string) (ip string, port string, err error) {
ipAddr := net.ParseIP(ipStr)
if ipAddr == nil {
// Port was included
ip, port, err = net.SplitHostPort(ipStr)
if err != nil {
return "", "", err
}
} else {
// No port was included
ip = ipAddr.String()
port = portDefault
}
return ip, port, nil
}

43
pkg/util/ip_test.go Normal file
View File

@ -0,0 +1,43 @@
package util
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestSplitIpPort(t *testing.T) {
Convey("When parsing an IPv4 without explicit port", t, func() {
ip, port, err := SplitIpPort("1.2.3.4", "5678")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "1.2.3.4")
So(port, ShouldEqual, "5678")
})
Convey("When parsing an IPv6 without explicit port", t, func() {
ip, port, err := SplitIpPort("::1", "5678")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "::1")
So(port, ShouldEqual, "5678")
})
Convey("When parsing an IPv4 with explicit port", t, func() {
ip, port, err := SplitIpPort("1.2.3.4:56", "78")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "1.2.3.4")
So(port, ShouldEqual, "56")
})
Convey("When parsing an IPv6 with explicit port", t, func() {
ip, port, err := SplitIpPort("[::1]:56", "78")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "::1")
So(port, ShouldEqual, "56")
})
}