Forward oauth tokens after prometheus datasource migration (#43686)

* create the prom client

* implement lru cache of prometheus clients based on auth headers

* linter
This commit is contained in:
Travis Patterson 2022-01-05 13:55:55 -07:00 committed by GitHub
parent 88d17c4998
commit 20b3b2a448
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 504 additions and 130 deletions

View File

@ -1,53 +0,0 @@
package client
import (
"strings"
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/prometheus/client_golang/api"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
)
func Create(url string, httpOpts sdkhttpclient.Options, clientProvider httpclient.Provider, jsonData map[string]interface{}, plog log.Logger) (apiv1.API, error) {
customParamsMiddleware := middleware.CustomQueryParameters(plog)
middlewares := []sdkhttpclient.Middleware{customParamsMiddleware}
if shouldForceGet(jsonData) {
middlewares = append(middlewares, middleware.ForceHttpGet(plog))
}
httpOpts.Middlewares = middlewares
roundTripper, err := clientProvider.GetTransport(httpOpts)
if err != nil {
return nil, err
}
cfg := api.Config{
Address: url,
RoundTripper: roundTripper,
}
client, err := api.NewClient(cfg)
if err != nil {
return nil, err
}
return apiv1.NewAPI(client), nil
}
func shouldForceGet(settingsJson map[string]interface{}) bool {
methodInterface, exists := settingsJson["httpMethod"]
if !exists {
return false
}
method, ok := methodInterface.(string)
if !ok {
return false
}
return strings.ToLower(method) == "get"
}

View File

@ -1,47 +0,0 @@
package client
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestForceGet(t *testing.T) {
t.Run("With nil jsonOpts, should not force get-method", func(t *testing.T) {
var jsonOpts map[string]interface{}
require.False(t, shouldForceGet(jsonOpts))
})
t.Run("With empty jsonOpts, should not force get-method", func(t *testing.T) {
jsonOpts := make(map[string]interface{})
require.False(t, shouldForceGet(jsonOpts))
})
t.Run("With httpMethod=nil, should not not force get-method", func(t *testing.T) {
jsonOpts := map[string]interface{}{
"httpMethod": nil,
}
require.False(t, shouldForceGet(jsonOpts))
})
t.Run("With httpMethod=post, should not force get-method", func(t *testing.T) {
jsonOpts := map[string]interface{}{
"httpMethod": "POST",
}
require.False(t, shouldForceGet(jsonOpts))
})
t.Run("With httpMethod=get, should force get-method", func(t *testing.T) {
jsonOpts := map[string]interface{}{
"httpMethod": "get",
}
require.True(t, shouldForceGet(jsonOpts))
})
t.Run("With httpMethod=GET, should force get-method", func(t *testing.T) {
jsonOpts := map[string]interface{}{
"httpMethod": "GET",
}
require.True(t, shouldForceGet(jsonOpts))
})
}

View File

@ -0,0 +1,55 @@
package promclient
import (
lru "github.com/hashicorp/golang-lru"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
)
const (
noPassThrough = "no-pass-through"
)
type ProviderCache struct {
provider promClientProvider
cache *lru.Cache
jsonData JsonData
}
type promClientProvider interface {
GetClient(map[string]string) (apiv1.API, error)
}
func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error) {
cache, err := lru.New(500)
if err != nil {
return nil, err
}
return &ProviderCache{
provider: p,
cache: cache,
jsonData: jd,
}, nil
}
func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error) {
key := c.key(headers)
if client, ok := c.cache.Get(key); ok {
return client.(apiv1.API), nil
}
client, err := c.provider.GetClient(headers)
if err != nil {
return nil, err
}
c.cache.Add(key, client)
return client, nil
}
func (c *ProviderCache) key(headers map[string]string) string {
if c.jsonData.OauthPassThru {
return headers[authHeader] + headers[idTokenHeader]
}
return noPassThrough
}

View File

@ -0,0 +1,144 @@
package promclient_test
import (
"context"
"errors"
"sort"
"strings"
"testing"
"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
"github.com/stretchr/testify/require"
)
func TestCache_GetClient(t *testing.T) {
t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
tc := setupCacheContext(true)
c, err := tc.providerCache.GetClient(headers)
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(headers)
require.Nil(t, err)
require.Equal(t, c, c2)
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it returns different clients when the auth headers differ", func(t *testing.T) {
tc := setupCacheContext(true)
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
c, err := tc.providerCache.GetClient(h1)
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(h2)
require.Nil(t, err)
require.NotEqual(t, c, c2)
require.Equal(t, 2, tc.clientProvider.numCalls)
})
t.Run("it always returns from the cache when 'oauthPassThru' not set", func(t *testing.T) {
tc := setupCacheContext(false)
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
c, err := tc.providerCache.GetClient(h1)
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(h2)
require.Nil(t, err)
require.Equal(t, c, c2)
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it only accounts for auth headers", func(t *testing.T) {
tc := setupCacheContext(true)
c, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "stuff"})
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "other-stuff"})
require.Nil(t, err)
require.Equal(t, c, c2)
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
tc := setupCacheContext(true)
tc.clientProvider.errors <- errors.New("something bad")
_, err := tc.providerCache.GetClient(headers)
require.EqualError(t, err, "something bad")
c, err := tc.providerCache.GetClient(headers)
require.Nil(t, err)
require.NotNil(t, c)
require.Equal(t, 2, tc.clientProvider.numCalls)
})
}
type cacheTestContext struct {
providerCache *promclient.ProviderCache
clientProvider *fakePromClientProvider
}
func setupCacheContext(oauthPassTrough bool) *cacheTestContext {
fp := newFakePromClientProvider()
p, err := promclient.NewProviderCache(fp, promclient.JsonData{OauthPassThru: oauthPassTrough})
if err != nil {
panic(err)
}
return &cacheTestContext{
providerCache: p,
clientProvider: fp,
}
}
func newFakePromClientProvider() *fakePromClientProvider {
return &fakePromClientProvider{
errors: make(chan error, 1),
}
}
type fakePromClientProvider struct {
headers map[string]string
numCalls int
errors chan error
}
func (p *fakePromClientProvider) GetClient(h map[string]string) (apiv1.API, error) {
p.headers = h
p.numCalls++
var err error
select {
case err = <-p.errors:
default:
}
var config []string
for _, v := range h {
config = append(config, v)
}
sort.Strings(config) //because map
return &fakePromClient{config: strings.Join(config, "")}, err
}
type fakePromClient struct {
apiv1.API
config string
}
func (c *fakePromClient) Config(ctx context.Context) (apiv1.ConfigResult, error) {
return apiv1.ConfigResult{YAML: c.config}, nil
}

View File

@ -0,0 +1,106 @@
package promclient
import (
"strings"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/prometheus/client_golang/api"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
)
const (
authHeader = "Authorization"
idTokenHeader = "X-ID-Token"
)
type Provider struct {
settings backend.DataSourceInstanceSettings
jsonData JsonData
clientProvider httpclient.Provider
log log.Logger
}
func NewProvider(
settings backend.DataSourceInstanceSettings,
jsonData JsonData,
clientProvider httpclient.Provider,
log log.Logger,
) *Provider {
return &Provider{
settings: settings,
jsonData: jsonData,
clientProvider: clientProvider,
log: log,
}
}
type JsonData struct {
Method string `json:"httpMethod"`
OauthPassThru bool `json:"oauthPassThru"`
TimeInterval string `json:"timeInterval"`
}
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
opts, err := p.settings.HTTPClientOptions()
if err != nil {
return nil, err
}
opts.Middlewares = p.middlewares()
if p.jsonData.OauthPassThru {
opts.Headers = authHeaders(headers)
}
// Set SigV4 service namespace
if opts.SigV4 != nil {
opts.SigV4.Service = "aps"
}
roundTripper, err := p.clientProvider.GetTransport(opts)
if err != nil {
return nil, err
}
cfg := api.Config{
Address: p.settings.URL,
RoundTripper: roundTripper,
}
client, err := api.NewClient(cfg)
if err != nil {
return nil, err
}
return apiv1.NewAPI(client), nil
}
func (p *Provider) middlewares() []sdkhttpclient.Middleware {
middlewares := []sdkhttpclient.Middleware{
middleware.CustomQueryParameters(p.log),
sdkhttpclient.CustomHeadersMiddleware(),
}
if strings.ToLower(p.jsonData.Method) == "get" {
middlewares = append(middlewares, middleware.ForceHttpGet(p.log))
}
return middlewares
}
func authHeaders(headers map[string]string) map[string]string {
authHeaders := make(map[string]string)
if v, ok := headers[authHeader]; ok {
authHeaders[authHeader] = v
}
if v, ok := headers[idTokenHeader]; ok {
authHeaders[idTokenHeader] = v
}
return authHeaders
}

View File

@ -0,0 +1,186 @@
package promclient_test
import (
"encoding/json"
"net/http"
"testing"
"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana-plugin-sdk-go/backend"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/stretchr/testify/require"
)
var headers = map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
func TestGetClient(t *testing.T) {
t.Run("it sets the SigV4 service if it exists", func(t *testing.T) {
tc := setup(`{"sigV4Auth":true}`)
setting.SigV4AuthEnabled = true
defer func() { setting.SigV4AuthEnabled = false }()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Equal(t, "aps", tc.httpProvider.opts.SigV4.Service)
})
t.Run("it always uses the custom params and custom headers middlewares", func(t *testing.T) {
tc := setup()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.middlewares(), 2)
require.Contains(t, tc.httpProvider.middlewares(), "prom-custom-query-parameters")
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders")
})
t.Run("oauth pass through", func(t *testing.T) {
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) {
tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Equal(t, headers, tc.httpProvider.opts.Headers)
})
t.Run("it only sets auth headers", func(t *testing.T) {
withNonAuth := map[string]string{"X-Not-Auth": "stuff"}
tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(withNonAuth)
require.Nil(t, err)
require.Equal(t, map[string]string{}, tc.httpProvider.opts.Headers)
})
t.Run("it does not error when headers are nil", func(t *testing.T) {
tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(nil)
require.Nil(t, err)
})
t.Run("it does not set the headers when 'oauthPassThru' is false", func(t *testing.T) {
tc := setup()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.opts.Headers, 0)
})
})
t.Run("force get middleware", func(t *testing.T) {
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) {
tc := setup(`{"httpMethod":"get"}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.middlewares(), 3)
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get")
})
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) {
tc := setup(`{"httpMethod":"GET"}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.middlewares(), 3)
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get")
})
t.Run("it does not add the force-get middleware when httpMethod is POST", func(t *testing.T) {
tc := setup(`{"httpMethod":"POST"}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get")
})
t.Run("it does not add the force-get middleware when json data is nil", func(t *testing.T) {
tc := setup()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get")
})
t.Run("it does not add the force-get middleware when json data is empty", func(t *testing.T) {
tc := setup(`{}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get")
})
t.Run("it does not add the force-get middleware httpMethod is null", func(t *testing.T) {
tc := setup(`{"httpMethod":null}`)
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get")
})
})
}
func setup(jsonData ...string) *testContext {
var rawData []byte
if len(jsonData) > 0 {
rawData = []byte(jsonData[0])
}
var jd promclient.JsonData
_ = json.Unmarshal(rawData, &jd)
settings := backend.DataSourceInstanceSettings{URL: "test-url", JSONData: rawData}
hp := &fakeHttpClientProvider{}
p := promclient.NewProvider(settings, jd, hp, nil)
return &testContext{
httpProvider: hp,
promClientProvider: p,
}
}
type testContext struct {
httpProvider *fakeHttpClientProvider
promClientProvider *promclient.Provider
}
type fakeHttpClientProvider struct {
httpclient.Provider
opts sdkhttpclient.Options
}
func (p *fakeHttpClientProvider) GetTransport(opts ...sdkhttpclient.Options) (http.RoundTripper, error) {
p.opts = opts[0]
return http.DefaultTransport, nil
}
func (p *fakeHttpClientProvider) middlewares() []string {
var middlewareNames []string
for _, m := range p.opts.Middlewares {
mw, ok := m.(sdkhttpclient.MiddlewareName)
if !ok {
panic("unexpected middleware type")
}
middlewareNames = append(middlewareNames, mw.MiddlewareName())
}
return middlewareNames
}

View File

@ -7,7 +7,7 @@ import (
"fmt"
"regexp"
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
@ -57,36 +57,14 @@ func ProvideService(cfg *setting.Cfg, httpClientProvider httpclient.Provider, pl
func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.InstanceFactoryFunc {
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
jsonData := map[string]interface{}{}
var jsonData promclient.JsonData
err := json.Unmarshal(settings.JSONData, &jsonData)
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
httpCliOpts, err := settings.HTTPClientOptions()
if err != nil {
return nil, fmt.Errorf("error getting http options: %w", err)
}
// Set SigV4 service namespace
if httpCliOpts.SigV4 != nil {
httpCliOpts.SigV4.Service = "aps"
}
// timeInterval can be a string or can be missing.
// if it is missing, we set it to empty-string
timeInterval := ""
timeIntervalJson := jsonData["timeInterval"]
if timeIntervalJson != nil {
// if it is not nil, it must be a string
var ok bool
timeInterval, ok = timeIntervalJson.(string)
if !ok {
return nil, errors.New("invalid time-interval provided")
}
}
client, err := client.Create(settings.URL, httpCliOpts, httpClientProvider, jsonData, plog)
p := promclient.NewProvider(settings, jsonData, httpClientProvider, plog)
pc, err := promclient.NewProviderCache(p, jsonData)
if err != nil {
return nil, err
}
@ -94,8 +72,8 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
mdl := DatasourceInfo{
ID: settings.ID,
URL: settings.URL,
TimeInterval: timeInterval,
promClient: client,
TimeInterval: jsonData.TimeInterval,
getClient: pc.GetClient,
}
return mdl, nil

View File

@ -48,7 +48,10 @@ const (
)
func (s *Service) executeTimeSeriesQuery(ctx context.Context, req *backend.QueryDataRequest, dsInfo *DatasourceInfo) (*backend.QueryDataResponse, error) {
client := dsInfo.promClient
client, err := dsInfo.getClient(req.Headers)
if err != nil {
return nil, err
}
result := backend.QueryDataResponse{
Responses: backend.Responses{},

View File

@ -11,9 +11,11 @@ type DatasourceInfo struct {
URL string
TimeInterval string
promClient apiv1.API
getClient clientGetter
}
type clientGetter func(map[string]string) (apiv1.API, error)
type PrometheusQuery struct {
Expr string
Step time.Duration