diff --git a/packages/grafana-ui/src/components/DataSourceSettings/TLSAuthSettings.tsx b/packages/grafana-ui/src/components/DataSourceSettings/TLSAuthSettings.tsx index f468d5f4aae..001b0ad7465 100644 --- a/packages/grafana-ui/src/components/DataSourceSettings/TLSAuthSettings.tsx +++ b/packages/grafana-ui/src/components/DataSourceSettings/TLSAuthSettings.tsx @@ -55,10 +55,10 @@ export const TLSAuthSettings: React.FC = ({ dataSourceCon ` )} > -
TLS Auth Details
+
TLS/SSL Auth Details
diff --git a/pkg/tsdb/postgres/locker.go b/pkg/tsdb/postgres/locker.go new file mode 100644 index 00000000000..62294703881 --- /dev/null +++ b/pkg/tsdb/postgres/locker.go @@ -0,0 +1,85 @@ +package postgres + +import ( + "fmt" + "sync" +) + +// locker is a named reader/writer mutual exclusion lock. +// The lock for each particular key can be held by an arbitrary number of readers or a single writer. +type locker struct { + locks map[interface{}]*sync.RWMutex + locksRW *sync.RWMutex +} + +func newLocker() *locker { + return &locker{ + locks: make(map[interface{}]*sync.RWMutex), + locksRW: new(sync.RWMutex), + } +} + +// Lock locks named rw mutex with specified key for writing. +// If the lock with the same key is already locked for reading or writing, +// Lock blocks until the lock is available. +func (lkr *locker) Lock(key interface{}) { + lk, ok := lkr.getLock(key) + if !ok { + lk = lkr.newLock(key) + } + lk.Lock() +} + +// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is +// not locked for writing on entry to Unlock. +func (lkr *locker) Unlock(key interface{}) { + lk, ok := lkr.getLock(key) + if !ok { + panic(fmt.Errorf("lock for key '%s' not initialized", key)) + } + lk.Unlock() +} + +// RLock locks named rw mutex with specified key for reading. +// +// It should not be used for recursive read locking for the same key; a blocked Lock +// call excludes new readers from acquiring the lock. See the +// documentation on the golang RWMutex type. +func (lkr *locker) RLock(key interface{}) { + lk, ok := lkr.getLock(key) + if !ok { + lk = lkr.newLock(key) + } + lk.RLock() +} + +// RUnlock undoes a single RLock call for specified key; +// it does not affect other simultaneous readers of locker for specified key. +// It is a run-time error if locker for specified key is not locked for reading +func (lkr *locker) RUnlock(key interface{}) { + lk, ok := lkr.getLock(key) + if !ok { + panic(fmt.Errorf("lock for key '%s' not initialized", key)) + } + lk.RUnlock() +} + +func (lkr *locker) newLock(key interface{}) *sync.RWMutex { + lkr.locksRW.Lock() + defer lkr.locksRW.Unlock() + + if lk, ok := lkr.locks[key]; ok { + return lk + } + lk := new(sync.RWMutex) + lkr.locks[key] = lk + return lk +} + +func (lkr *locker) getLock(key interface{}) (*sync.RWMutex, bool) { + lkr.locksRW.RLock() + defer lkr.locksRW.RUnlock() + + lock, ok := lkr.locks[key] + return lock, ok +} diff --git a/pkg/tsdb/postgres/locker_test.go b/pkg/tsdb/postgres/locker_test.go new file mode 100644 index 00000000000..87d8b27dc8a --- /dev/null +++ b/pkg/tsdb/postgres/locker_test.go @@ -0,0 +1,63 @@ +package postgres + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLocker(t *testing.T) { + if testing.Short() { + t.Skip("Tests with Sleep") + } + const notUpdated = "not_updated" + const atThread1 = "at_thread_1" + const atThread2 = "at_thread_2" + t.Run("Should lock for same keys", func(t *testing.T) { + updated := notUpdated + locker := newLocker() + locker.Lock(1) + var wg sync.WaitGroup + wg.Add(1) + defer func() { + locker.Unlock(1) + wg.Wait() + }() + + go func() { + locker.RLock(1) + defer func() { + locker.RUnlock(1) + wg.Done() + }() + require.Equal(t, atThread1, updated, "Value should be updated in different thread") + updated = atThread2 + }() + time.Sleep(time.Millisecond * 10) + require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") + updated = atThread1 + }) + + t.Run("Should not lock for different keys", func(t *testing.T) { + updated := notUpdated + locker := newLocker() + locker.Lock(1) + defer locker.Unlock(1) + var wg sync.WaitGroup + wg.Add(1) + go func() { + locker.RLock(2) + defer func() { + locker.RUnlock(2) + wg.Done() + }() + require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") + updated = atThread2 + }() + wg.Wait() + require.Equal(t, atThread2, updated, "Value should be updated in different thread") + updated = atThread1 + }) +} diff --git a/pkg/tsdb/postgres/postgres.go b/pkg/tsdb/postgres/postgres.go index 12cf29b3c43..504fc18e595 100644 --- a/pkg/tsdb/postgres/postgres.go +++ b/pkg/tsdb/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util/errutil" @@ -13,24 +14,43 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/tsdb" "github.com/grafana/grafana/pkg/tsdb/sqleng" + "xorm.io/core" ) func init() { - tsdb.RegisterTsdbQueryEndpoint("postgres", newPostgresQueryEndpoint) + registry.Register(®istry.Descriptor{ + Name: "PostgresService", + InitPriority: registry.Low, + Instance: &postgresService{}, + }) } -func newPostgresQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) { - logger := log.New("tsdb.postgres") - logger.Debug("Creating Postgres query endpoint") +type postgresService struct { + Cfg *setting.Cfg `inject:""` + logger log.Logger + tlsManager tlsSettingsProvider +} - cnnstr, err := generateConnectionString(datasource, logger) +func (s *postgresService) Init() error { + s.logger = log.New("tsdb.postgres") + s.tlsManager = newTLSManager(s.logger, s.Cfg.DataPath) + tsdb.RegisterTsdbQueryEndpoint("postgres", func(ds *models.DataSource) (tsdb.TsdbQueryEndpoint, error) { + return s.newPostgresQueryEndpoint(ds) + }) + return nil +} + +func (s *postgresService) newPostgresQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) { + s.logger.Debug("Creating Postgres query endpoint") + + cnnstr, err := s.generateConnectionString(datasource) if err != nil { return nil, err } - if setting.Env == setting.Dev { - logger.Debug("getEngine", "connection", cnnstr) + if s.Cfg.Env == setting.Dev { + s.logger.Debug("getEngine", "connection", cnnstr) } config := sqleng.SqlQueryEndpointConfiguration{ @@ -41,18 +61,19 @@ func newPostgresQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndp } queryResultTransformer := postgresQueryResultTransformer{ - log: logger, + log: s.logger, } timescaledb := datasource.JsonData.Get("timescaledb").MustBool(false) - endpoint, err := sqleng.NewSqlQueryEndpoint(&config, &queryResultTransformer, newPostgresMacroEngine(timescaledb), logger) + endpoint, err := sqleng.NewSqlQueryEndpoint(&config, &queryResultTransformer, newPostgresMacroEngine(timescaledb), + s.logger) if err != nil { - logger.Debug("Failed connecting to Postgres", "err", err) + s.logger.Error("Failed connecting to Postgres", "err", err) return nil, err } - logger.Debug("Successfully connected to Postgres") + s.logger.Debug("Successfully connected to Postgres") return endpoint, err } @@ -61,15 +82,13 @@ func escape(input string) string { return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) } -func generateConnectionString(datasource *models.DataSource, logger log.Logger) (string, error) { - tlsMode := strings.TrimSpace(strings.ToLower(datasource.JsonData.Get("sslmode").MustString("verify-full"))) - isTLSDisabled := tlsMode == "disable" - +func (s *postgresService) generateConnectionString(datasource *models.DataSource) (string, error) { var host string var port int + var err error if strings.HasPrefix(datasource.Url, "/") { host = datasource.Url - logger.Debug("Generating connection string with Unix socket specifier", "socket", host) + s.logger.Debug("Generating connection string with Unix socket specifier", "socket", host) } else { sp := strings.SplitN(datasource.Url, ":", 2) host = sp[0] @@ -80,41 +99,41 @@ func generateConnectionString(datasource *models.DataSource, logger log.Logger) return "", errutil.Wrapf(err, "invalid port in host specifier %q", sp[1]) } - logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port) + s.logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port) } else { - logger.Debug("Generating connection string with network host", "host", host) + s.logger.Debug("Generating connection string with network host", "host", host) } } - connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s' sslmode='%s'", - escape(datasource.User), escape(datasource.DecryptedPassword()), escape(host), escape(datasource.Database), - escape(tlsMode)) + connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", + escape(datasource.User), escape(datasource.DecryptedPassword()), escape(host), escape(datasource.Database)) if port > 0 { connStr += fmt.Sprintf(" port=%d", port) } - if isTLSDisabled { - logger.Debug("Postgres TLS/SSL is disabled") - } else { - logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsMode) - // Attach root certificate if provided - if tlsRootCert := datasource.JsonData.Get("sslRootCertFile").MustString(""); tlsRootCert != "" { - logger.Debug("Setting server root certificate", "tlsRootCert", tlsRootCert) - connStr += fmt.Sprintf(" sslrootcert='%s'", tlsRootCert) - } - - // Attach client certificate and key if both are provided - tlsCert := datasource.JsonData.Get("sslCertFile").MustString("") - tlsKey := datasource.JsonData.Get("sslKeyFile").MustString("") - if tlsCert != "" && tlsKey != "" { - logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsCert, "tlsKey", tlsKey) - connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", tlsCert, tlsKey) - } else if tlsCert != "" || tlsKey != "" { - return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified") - } + tlsSettings, err := s.tlsManager.getTLSSettings(datasource) + if err != nil { + return "", err } - logger.Debug("Generated Postgres connection string successfully") + connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode)) + + // Attach root certificate if provided + // Attach root certificate if provided + if tlsSettings.RootCertFile != "" { + s.logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile) + connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile)) + } + + // Attach client certificate and key if both are provided + if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" { + s.logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile) + connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile)) + } else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" { + return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified") + } + + s.logger.Debug("Generated Postgres connection string successfully") return connStr, nil } diff --git a/pkg/tsdb/postgres/postgres_test.go b/pkg/tsdb/postgres/postgres_test.go index ea649edd1c9..66e7d448e1f 100644 --- a/pkg/tsdb/postgres/postgres_test.go +++ b/pkg/tsdb/postgres/postgres_test.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/sqlutil" + "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb" "github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/stretchr/testify/assert" @@ -27,84 +28,110 @@ import ( // Test generateConnectionString. func TestGenerateConnectionString(t *testing.T) { - logger := log.New("tsdb.postgres") + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() testCases := []struct { - desc string - host string - user string - password string - database string - tlsMode string - expConnStr string - expErr string + desc string + host string + user string + password string + database string + tlsSettings tlsSettings + expConnStr string + expErr string + uid string }{ { - desc: "Unix socket host", - host: "/var/run/postgresql", - user: "user", - password: "password", - database: "database", - expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", + desc: "Unix socket host", + host: "/var/run/postgresql", + user: "user", + password: "password", + database: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", }, { - desc: "TCP host", - host: "host", - user: "user", - password: "password", - database: "database", - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", + desc: "TCP host", + host: "host", + user: "user", + password: "password", + database: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", }, { - desc: "TCP/port host", - host: "host:1234", - user: "user", - password: "password", - database: "database", - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' port=1234", + desc: "TCP/port host", + host: "host:1234", + user: "user", + password: "password", + database: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='host' dbname='database' port=1234 sslmode='verify-full'", }, { - desc: "Invalid port", - host: "host:invalid", + desc: "Invalid port", + host: "host:invalid", + user: "user", + database: "database", + tlsSettings: tlsSettings{}, + expErr: "invalid port in host specifier", + }, + { + desc: "Password with single quote and backslash", + host: "host", + user: "user", + password: `p'\assword`, + database: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, + }, + { + desc: "Custom TLS mode disabled", + host: "host", + user: "user", + password: "password", + database: "database", + tlsSettings: tlsSettings{Mode: "disable"}, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", + }, + { + desc: "Custom TLS mode verify-full with certificate files", + host: "host", user: "user", + password: "password", database: "database", - expErr: "invalid port in host specifier", - }, - { - desc: "Password with single quote and backslash", - host: "host", - user: "user", - password: `p'\assword`, - database: "database", - expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, - }, - { - desc: "Custom TLS/SSL mode", - host: "host", - user: "user", - password: "password", - database: "database", - tlsMode: "disable", - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", + tlsSettings: tlsSettings{ + Mode: "verify-full", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + }, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' " + + "sslrootcert='i/am/coding/ca.crt' sslcert='i/am/coding/client.crt' sslkey='i/am/coding/client.key'", }, } for _, tt := range testCases { t.Run(tt.desc, func(t *testing.T) { - data := map[string]interface{}{} - if tt.tlsMode != "" { - data["sslmode"] = tt.tlsMode + svc := postgresService{ + Cfg: cfg, + logger: log.New("tsdb.postgres"), + tlsManager: &tlsTestManager{settings: tt.tlsSettings}, } + ds := &models.DataSource{ Url: tt.host, User: tt.user, Password: tt.password, Database: tt.database, - JsonData: simplejson.NewFromAny(data), + Uid: tt.uid, } - connStr, err := generateConnectionString(ds, logger) + + connStr, err := svc.generateConnectionString(ds) + if tt.expErr == "" { require.NoError(t, err, tt.desc) - assert.Equal(t, tt.expConnStr, connStr, tt.desc) + assert.Equal(t, tt.expConnStr, connStr) } else { require.Error(t, err, tt.desc) assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), @@ -127,7 +154,7 @@ func TestPostgres(t *testing.T) { runPostgresTests := false // runPostgresTests := true - if !(sqlstore.IsTestDbPostgres() || runPostgresTests) { + if !sqlstore.IsTestDbPostgres() && !runPostgresTests { t.Skip() } @@ -146,7 +173,15 @@ func TestPostgres(t *testing.T) { return sql, nil } - endpoint, err := newPostgresQueryEndpoint(&models.DataSource{ + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() + svc := postgresService{ + Cfg: cfg, + logger: log.New("tsdb.postgres"), + tlsManager: &tlsTestManager{settings: tlsSettings{Mode: "disable"}}, + } + + endpoint, err := svc.newPostgresQueryEndpoint(&models.DataSource{ JsonData: simplejson.New(), SecureJsonData: securejsondata.SecureJsonData{}, }) @@ -483,9 +518,9 @@ func TestPostgres(t *testing.T) { ValueTwo int64 `xorm:"integer 'valueTwo'"` } - if exist, err := sess.IsTableExist(metric_values{}); err != nil || exist { + if exists, err := sess.IsTableExist(metric_values{}); err != nil || exists { require.NoError(t, err) - err = sess.DropTable(metric_values{}) + err := sess.DropTable(metric_values{}) require.NoError(t, err) } err := sess.CreateTable(metric_values{}) @@ -1084,9 +1119,7 @@ func InitPostgresTestDB(t *testing.T) *xorm.Engine { testDB := sqlutil.PostgresTestDB() x, err := xorm.NewEngine(testDB.DriverName, strings.Replace(testDB.ConnStr, "dbname=grafanatest", "dbname=grafanadstest", 1)) - if err != nil { - t.Fatalf("Failed to init postgres DB %v", err) - } + require.NoError(t, err, "Failed to init postgres DB") x.DatabaseTZ = time.UTC x.TZLocation = time.UTC @@ -1108,3 +1141,11 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim return timeRange } + +type tlsTestManager struct { + settings tlsSettings +} + +func (m *tlsTestManager) getTLSSettings(datasource *models.DataSource) (tlsSettings, error) { + return m.settings, nil +} diff --git a/pkg/tsdb/postgres/tlsmanager.go b/pkg/tsdb/postgres/tlsmanager.go new file mode 100644 index 00000000000..eadb67bf491 --- /dev/null +++ b/pkg/tsdb/postgres/tlsmanager.go @@ -0,0 +1,228 @@ +package postgres + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + + "github.com/grafana/grafana/pkg/infra/fs" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/models" +) + +var validateCertFunc = validateCertFilePaths +var writeCertFileFunc = writeCertFile + +type tlsSettingsProvider interface { + getTLSSettings(datasource *models.DataSource) (tlsSettings, error) +} + +type datasourceCacheManager struct { + locker *locker + cache sync.Map +} + +type tlsManager struct { + logger log.Logger + dsCacheInstance datasourceCacheManager + dataPath string +} + +func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider { + return &tlsManager{ + logger: logger, + dataPath: dataPath, + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + } +} + +type tlsSettings struct { + Mode string + ConfigurationMethod string + RootCertFile string + CertFile string + CertKeyFile string +} + +func (m *tlsManager) getTLSSettings(datasource *models.DataSource) (tlsSettings, error) { + tlsMode := strings.TrimSpace(strings.ToLower(datasource.JsonData.Get("sslmode").MustString("verify-full"))) + isTLSDisabled := tlsMode == "disable" + + settings := tlsSettings{} + settings.Mode = tlsMode + + if isTLSDisabled { + m.logger.Debug("Postgres TLS/SSL is disabled") + return settings, nil + } + + m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsMode) + + settings.ConfigurationMethod = strings.TrimSpace( + strings.ToLower(datasource.JsonData.Get("tlsConfigurationMethod").MustString("file-path"))) + + if settings.ConfigurationMethod == "file-content" { + if err := m.writeCertFiles(datasource, &settings); err != nil { + return settings, err + } + } else { + settings.RootCertFile = datasource.JsonData.Get("sslRootCertFile").MustString("") + settings.CertFile = datasource.JsonData.Get("sslCertFile").MustString("") + settings.CertKeyFile = datasource.JsonData.Get("sslKeyFile").MustString("") + if err := validateCertFunc(settings.RootCertFile, settings.CertFile, settings.CertKeyFile); err != nil { + return settings, err + } + } + return settings, nil +} + +type certFileType int + +const ( + rootCert = iota + clientCert + clientKey +) + +func (t certFileType) String() string { + switch t { + case rootCert: + return "root certificate" + case clientCert: + return "client certificate" + case clientKey: + return "client key" + default: + panic(fmt.Sprintf("Unrecognized certFileType %d", t)) + } +} + +func getFileName(dataDir string, fileType certFileType) string { + var filename string + switch fileType { + case rootCert: + filename = "root.crt" + case clientCert: + filename = "client.crt" + case clientKey: + filename = "client.key" + default: + panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String())) + } + generatedFilePath := filepath.Join(dataDir, filename) + return generatedFilePath +} + +// writeCertFile writes a certificate file. +func writeCertFile( + ds *models.DataSource, logger log.Logger, fileContent string, generatedFilePath string) error { + fileContent = strings.TrimSpace(fileContent) + if fileContent != "" { + logger.Debug("Writing cert file", "path", generatedFilePath) + if err := ioutil.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil { + return err + } + // Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail + if err := os.Chmod(generatedFilePath, 0600); err != nil { + return err + } + return nil + } + + logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath) + exists, err := fs.Exists(generatedFilePath) + if err != nil { + return err + } + if exists { + if err := os.Remove(generatedFilePath); err != nil { + return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err) + } + } + return nil +} + +func (m *tlsManager) writeCertFiles(ds *models.DataSource, settings *tlsSettings) error { + m.logger.Debug("Writing TLS certificate files to disk") + decrypted := ds.DecryptedValues() + tlsRootCert := decrypted["tlsCACert"] + tlsClientCert := decrypted["tlsClientCert"] + tlsClientKey := decrypted["tlsClientKey"] + + if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" { + m.logger.Debug("No TLS/SSL certificates provided") + } + + // Calculate all files path + workDir := filepath.Join(m.dataPath, "tls", ds.Uid+"generatedTLSCerts") + settings.RootCertFile = getFileName(workDir, rootCert) + settings.CertFile = getFileName(workDir, clientCert) + settings.CertKeyFile = getFileName(workDir, clientKey) + + // Find datasource in the cache, if found, skip writing files + cacheKey := strconv.Itoa(int(ds.Id)) + m.dsCacheInstance.locker.RLock(cacheKey) + item, ok := m.dsCacheInstance.cache.Load(cacheKey) + m.dsCacheInstance.locker.RUnlock(cacheKey) + if ok { + if item.(int) == ds.Version { + return nil + } + } + + m.dsCacheInstance.locker.Lock(cacheKey) + defer m.dsCacheInstance.locker.Unlock(cacheKey) + + item, ok = m.dsCacheInstance.cache.Load(cacheKey) + if ok { + if item.(int) == ds.Version { + return nil + } + } + + // Write certification directory and files + exists, err := fs.Exists(workDir) + if err != nil { + return err + } + if !exists { + if err := os.MkdirAll(workDir, 0700); err != nil { + return err + } + } + + if err = writeCertFileFunc(ds, m.logger, tlsRootCert, settings.RootCertFile); err != nil { + return err + } + if err = writeCertFileFunc(ds, m.logger, tlsClientCert, settings.CertFile); err != nil { + return err + } + if err = writeCertFileFunc(ds, m.logger, tlsClientKey, settings.CertKeyFile); err != nil { + return err + } + + // Update datasource cache + m.dsCacheInstance.cache.Store(cacheKey, ds.Version) + return nil +} + +// validateCertFilePaths validates configured certificate file paths. +func validateCertFilePaths(rootCert, clientCert, clientKey string) error { + for _, fpath := range []string{rootCert, clientCert, clientKey} { + if fpath == "" { + continue + } + exists, err := fs.Exists(fpath) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("certificate file %q doesn't exist", fpath) + } + } + return nil +} diff --git a/pkg/tsdb/postgres/tlsmanager_test.go b/pkg/tsdb/postgres/tlsmanager_test.go new file mode 100644 index 00000000000..d08b85bfe77 --- /dev/null +++ b/pkg/tsdb/postgres/tlsmanager_test.go @@ -0,0 +1,290 @@ +package postgres + +import ( + "fmt" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + + "github.com/grafana/grafana/pkg/components/securejsondata" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + _ "github.com/lib/pq" +) + +var writeCertFileCallNum int + +// TestDataSourceCacheManager is to test the Cache manager +func TestDataSourceCacheManager(t *testing.T) { + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() + mng := tlsManager{ + logger: log.New("tsdb.postgres"), + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + dataPath: cfg.DataPath, + } + + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "sslmode": "verify-full", + "tlsConfigurationMethod": "file-content", + }) + secureJSONData := securejsondata.GetEncryptedJsonData(map[string]string{ + "tlsClientCert": "I am client certification", + "tlsClientKey": "I am client key", + "tlsCACert": "I am CA certification", + }) + + mockValidateCertFilePaths() + t.Cleanup(resetValidateCertFilePaths) + + t.Run("Check datasource cache creation", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(10) + for id := int64(1); id <= 10; id++ { + go func(id int64) { + ds := &models.DataSource{ + Id: id, + Version: 1, + Database: "database", + JsonData: jsonData, + SecureJsonData: secureJSONData, + Uid: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(ds, &s) + require.NoError(t, err) + wg.Done() + }(id) + } + wg.Wait() + + t.Run("check cache creation is succeed", func(t *testing.T) { + for id := int64(1); id <= 10; id++ { + version, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id))) + require.True(t, ok) + require.Equal(t, int(1), version) + } + }) + }) + + t.Run("Check datasource cache modification", func(t *testing.T) { + t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) { + mockWriteCertFile() + t.Cleanup(resetWriteCertFile) + var wg1 sync.WaitGroup + wg1.Add(5) + for id := int64(1); id <= 5; id++ { + go func(id int64) { + ds := &models.DataSource{ + Id: 1, + Version: 2, + Database: "database", + JsonData: jsonData, + SecureJsonData: secureJSONData, + Uid: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(ds, &s) + require.NoError(t, err) + wg1.Done() + }(id) + } + wg1.Wait() + assert.Equal(t, writeCertFileCallNum, 3) + }) + + t.Run("cache is updated with the last datasource version", func(t *testing.T) { + dsV2 := &models.DataSource{ + Id: 1, + Version: 2, + Database: "database", + JsonData: jsonData, + SecureJsonData: secureJSONData, + Uid: "testData", + } + dsV3 := &models.DataSource{ + Id: 1, + Version: 3, + Database: "database", + JsonData: jsonData, + SecureJsonData: secureJSONData, + Uid: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(dsV2, &s) + require.NoError(t, err) + err = mng.writeCertFiles(dsV3, &s) + require.NoError(t, err) + version, ok := mng.dsCacheInstance.cache.Load("1") + require.True(t, ok) + require.Equal(t, int(3), version) + }) + }) +} + +// Test getFileName + +func TestGetFileName(t *testing.T) { + testCases := []struct { + desc string + datadir string + fileType certFileType + expErr string + expectedGeneratedPath string + }{ + { + desc: "Get File Name for root certification", + datadir: ".", + fileType: rootCert, + expectedGeneratedPath: "root.crt", + }, + { + desc: "Get File Name for client certification", + datadir: ".", + fileType: clientCert, + expectedGeneratedPath: "client.crt", + }, + { + desc: "Get File Name for client certification", + datadir: ".", + fileType: clientKey, + expectedGeneratedPath: "client.key", + }, + } + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + generatedPath := getFileName(tt.datadir, tt.fileType) + assert.Equal(t, tt.expectedGeneratedPath, generatedPath) + }) + } +} + +// Test getTLSSettings. +func TestGetTLSSettings(t *testing.T) { + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() + + mockValidateCertFilePaths() + t.Cleanup(resetValidateCertFilePaths) + testCases := []struct { + desc string + expErr string + jsonData map[string]interface{} + secureJSONData map[string]string + uid string + tlsSettings tlsSettings + version int + }{ + { + desc: "Custom TLS authentication disabled", + version: 1, + jsonData: map[string]interface{}{ + "sslmode": "disable", + "sslRootCertFile": "i/am/coding/ca.crt", + "sslCertFile": "i/am/coding/client.crt", + "sslKeyFile": "i/am/coding/client.key", + "tlsConfigurationMethod": "file-path", + }, + tlsSettings: tlsSettings{Mode: "disable"}, + }, + { + desc: "Custom TLS authentication with file path", + version: 2, + jsonData: map[string]interface{}{ + "sslmode": "verify-full", + "sslRootCertFile": "i/am/coding/ca.crt", + "sslCertFile": "i/am/coding/client.crt", + "sslKeyFile": "i/am/coding/client.key", + "tlsConfigurationMethod": "file-path", + }, + tlsSettings: tlsSettings{ + Mode: "verify-full", + ConfigurationMethod: "file-path", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + }, + }, + { + desc: "Custom TLS mode verify-full with certificate files content", + version: 3, + uid: "xxx", + jsonData: map[string]interface{}{ + "sslmode": "verify-full", + "tlsConfigurationMethod": "file-content", + }, + secureJSONData: map[string]string{ + "tlsCACert": "I am CA certification", + "tlsClientCert": "I am client certification", + "tlsClientKey": "I am client key", + }, + tlsSettings: tlsSettings{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), + CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), + CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), + }, + }, + } + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + var settings tlsSettings + var err error + mng := tlsManager{ + logger: log.New("tsdb.postgres"), + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + dataPath: cfg.DataPath, + } + + jsonData := simplejson.NewFromAny(tt.jsonData) + ds := &models.DataSource{ + JsonData: jsonData, + SecureJsonData: securejsondata.GetEncryptedJsonData(tt.secureJSONData), + Uid: tt.uid, + Version: tt.version, + } + + settings, err = mng.getTLSSettings(ds) + + if tt.expErr == "" { + require.NoError(t, err, tt.desc) + assert.Equal(t, tt.tlsSettings, settings) + } else { + require.Error(t, err, tt.desc) + assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), + fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr)) + } + }) + } +} + +func mockValidateCertFilePaths() { + validateCertFunc = func(rootCert, clientCert, clientKey string) error { + return nil + } +} + +func resetValidateCertFilePaths() { + validateCertFunc = validateCertFilePaths +} + +func mockWriteCertFile() { + writeCertFileCallNum = 0 + writeCertFileFunc = func(ds *models.DataSource, logger log.Logger, fileContent string, generatedFilePath string) error { + writeCertFileCallNum++ + return nil + } +} + +func resetWriteCertFile() { + writeCertFileCallNum = 0 + writeCertFileFunc = writeCertFile +} diff --git a/public/app/features/datasources/partials/tls_auth_settings.html b/public/app/features/datasources/partials/tls_auth_settings.html index 7760edb8274..f15193bf63b 100644 --- a/public/app/features/datasources/partials/tls_auth_settings.html +++ b/public/app/features/datasources/partials/tls_auth_settings.html @@ -1,7 +1,7 @@
-
TLS Auth Details
- TLS Certs are encrypted and stored in the Grafana database. +
TLS/SSL Auth Details
+ TLS/SSL certificates are encrypted and stored in the Grafana database.
diff --git a/public/app/plugins/datasource/postgres/config_ctrl.ts b/public/app/plugins/datasource/postgres/config_ctrl.ts index 6e17c0d4217..85449dcaf68 100644 --- a/public/app/plugins/datasource/postgres/config_ctrl.ts +++ b/public/app/plugins/datasource/postgres/config_ctrl.ts @@ -19,11 +19,13 @@ export class PostgresConfigCtrl { constructor($scope: any, datasourceSrv: DatasourceSrv) { this.datasourceSrv = datasourceSrv; this.current.jsonData.sslmode = this.current.jsonData.sslmode || 'verify-full'; + this.current.jsonData.tlsConfigurationMethod = this.current.jsonData.tlsConfigurationMethod || 'file-path'; this.current.jsonData.postgresVersion = this.current.jsonData.postgresVersion || 903; this.showTimescaleDBHelp = false; this.autoDetectFeatures(); this.onPasswordReset = createResetHandler(this, PasswordFieldEnum.Password); this.onPasswordChange = createChangeHandler(this, PasswordFieldEnum.Password); + this.tlsModeMapping(); } autoDetectFeatures() { @@ -62,6 +64,18 @@ export class PostgresConfigCtrl { this.showTimescaleDBHelp = !this.showTimescaleDBHelp; } + tlsModeMapping() { + if (this.current.jsonData.sslmode === 'disable') { + this.current.jsonData.tlsAuth = false; + this.current.jsonData.tlsAuthWithCACert = false; + this.current.jsonData.tlsSkipVerify = true; + } else { + this.current.jsonData.tlsAuth = true; + this.current.jsonData.tlsAuthWithCACert = true; + this.current.jsonData.tlsSkipVerify = false; + } + } + // the value portion is derived from postgres server_version_num/100 postgresVersions = [ { name: '9.3', value: 903 }, diff --git a/public/app/plugins/datasource/postgres/partials/config.html b/public/app/plugins/datasource/postgres/partials/config.html index 98ddc0b29e8..9bff000e216 100644 --- a/public/app/plugins/datasource/postgres/partials/config.html +++ b/public/app/plugins/datasource/postgres/partials/config.html @@ -28,45 +28,72 @@ />
+
+ ng-init="ctrl.current.jsonData.sslmode" ng-change="ctrl.tlsModeMapping()"> This option determines whether or with what priority a secure TLS/SSL TCP/IP connection will be negotiated with the server.
-
+ +
+ +
+ + + This option determines how TLS/SSL certifications are configured. Selecting File system path will allow + you to configure certificates by specifying paths to existing certificates on the local file system where + Grafana is running. Be sure that the file is readable by the user executing the Grafana process.

+ + Selecting Certificate content will allow you to configure certificates by specifying its content. + The content will be stored encrypted in Grafana's database. When connecting to the database the certificates + will be written as files to Grafana's configured data path on the local file system. +
+
+
+
+ +
+
+
TLS/SSL Auth Details
+
+
TLS/SSL Root Certificate - If the selected TLS/SSL mode requires a server root certificate, provide the path to the file here. - Be sure that the file is readable by the user executing the grafana process. + If the selected TLS/SSL mode requires a server root certificate, provide the path to the file here.
-
+
TLS/SSL Client Certificate + placeholder="TLS/SSL client cert file"> To authenticate with an TLS/SSL client certificate, provide the path to the file here. Be sure that the file is readable by the user executing the grafana process.
-
+
TLS/SSL Client Key + placeholder="TLS/SSL client key file"> To authenticate with a client TLS/SSL certificate, provide the path to the corresponding key file here. Be sure that the file is only readable by the user executing the grafana process.
+ + Connection limits @@ -74,7 +101,7 @@
Max open + ng-model="ctrl.current.jsonData.maxOpenConns" placeholder="unlimited"> The maximum number of open connections to the database. If Max idle connections is greater than 0 and the Max open connections is less than Max idle connections, then Max idle connections will be @@ -85,7 +112,7 @@
Max idle + ng-model="ctrl.current.jsonData.maxIdleConns" placeholder="2"> The maximum number of connections in the idle connection pool. If Max open connections is greater than 0 but less than the Max idle connections, then the Max idle connections will be reduced to match the @@ -95,7 +122,7 @@
Max lifetime + ng-model="ctrl.current.jsonData.connMaxLifetime" placeholder="14400"> The maximum amount of time in seconds a connection may be reused. If set to 0, connections are reused forever.