mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Propagate all headers (#43812)
* Propagate all headers * stable header order
This commit is contained in:
@@ -1,25 +1,23 @@
|
||||
package promclient
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
noPassThrough = "no-pass-through"
|
||||
)
|
||||
|
||||
type ProviderCache struct {
|
||||
provider promClientProvider
|
||||
cache *lru.Cache
|
||||
jsonData JsonData
|
||||
}
|
||||
|
||||
type promClientProvider interface {
|
||||
GetClient(map[string]string) (apiv1.API, error)
|
||||
}
|
||||
|
||||
func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error) {
|
||||
func NewProviderCache(p promClientProvider) (*ProviderCache, error) {
|
||||
cache, err := lru.New(500)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -28,7 +26,6 @@ func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error)
|
||||
return &ProviderCache{
|
||||
provider: p,
|
||||
cache: cache,
|
||||
jsonData: jd,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -48,8 +45,12 @@ func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error)
|
||||
}
|
||||
|
||||
func (c *ProviderCache) key(headers map[string]string) string {
|
||||
if c.jsonData.OauthPassThru {
|
||||
return headers[authHeader] + headers[idTokenHeader]
|
||||
vals := make([]string, len(headers))
|
||||
var i int
|
||||
for _, v := range headers {
|
||||
vals[i] = v
|
||||
i++
|
||||
}
|
||||
return noPassThrough
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, "")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
func TestCache_GetClient(t *testing.T) {
|
||||
t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
|
||||
tc := setupCacheContext(true)
|
||||
tc := setupCacheContext()
|
||||
|
||||
c, err := tc.providerCache.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
@@ -28,8 +28,8 @@ func TestCache_GetClient(t *testing.T) {
|
||||
require.Equal(t, 1, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it returns different clients when the auth headers differ", func(t *testing.T) {
|
||||
tc := setupCacheContext(true)
|
||||
t.Run("it returns different clients when the headers differ", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
|
||||
|
||||
@@ -43,10 +43,10 @@ func TestCache_GetClient(t *testing.T) {
|
||||
require.Equal(t, 2, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it always returns from the cache when 'oauthPassThru' not set", func(t *testing.T) {
|
||||
tc := setupCacheContext(false)
|
||||
t.Run("it returns from the cache when headers are the same", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
|
||||
h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
|
||||
c, err := tc.providerCache.GetClient(h1)
|
||||
require.Nil(t, err)
|
||||
@@ -58,21 +58,8 @@ func TestCache_GetClient(t *testing.T) {
|
||||
require.Equal(t, 1, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it only accounts for auth headers", func(t *testing.T) {
|
||||
tc := setupCacheContext(true)
|
||||
|
||||
c, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "stuff"})
|
||||
require.Nil(t, err)
|
||||
|
||||
c2, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "other-stuff"})
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, c, c2)
|
||||
require.Equal(t, 1, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
|
||||
tc := setupCacheContext(true)
|
||||
tc := setupCacheContext()
|
||||
tc.clientProvider.errors <- errors.New("something bad")
|
||||
|
||||
_, err := tc.providerCache.GetClient(headers)
|
||||
@@ -91,9 +78,9 @@ type cacheTestContext struct {
|
||||
clientProvider *fakePromClientProvider
|
||||
}
|
||||
|
||||
func setupCacheContext(oauthPassTrough bool) *cacheTestContext {
|
||||
func setupCacheContext() *cacheTestContext {
|
||||
fp := newFakePromClientProvider()
|
||||
p, err := promclient.NewProviderCache(fp, promclient.JsonData{OauthPassThru: oauthPassTrough})
|
||||
p, err := promclient.NewProviderCache(fp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,6 @@ import (
|
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
authHeader = "Authorization"
|
||||
idTokenHeader = "X-ID-Token"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
settings backend.DataSourceInstanceSettings
|
||||
jsonData JsonData
|
||||
@@ -41,9 +36,8 @@ func NewProvider(
|
||||
}
|
||||
|
||||
type JsonData struct {
|
||||
Method string `json:"httpMethod"`
|
||||
OauthPassThru bool `json:"oauthPassThru"`
|
||||
TimeInterval string `json:"timeInterval"`
|
||||
Method string `json:"httpMethod"`
|
||||
TimeInterval string `json:"timeInterval"`
|
||||
}
|
||||
|
||||
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
|
||||
@@ -53,9 +47,7 @@ func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
|
||||
}
|
||||
|
||||
opts.Middlewares = p.middlewares()
|
||||
if p.jsonData.OauthPassThru {
|
||||
opts.Headers = authHeaders(headers)
|
||||
}
|
||||
opts.Headers = reqHeaders(headers)
|
||||
|
||||
// Set SigV4 service namespace
|
||||
if opts.SigV4 != nil {
|
||||
@@ -92,15 +84,11 @@ func (p *Provider) middlewares() []sdkhttpclient.Middleware {
|
||||
return middlewares
|
||||
}
|
||||
|
||||
func authHeaders(headers map[string]string) map[string]string {
|
||||
authHeaders := make(map[string]string)
|
||||
if v, ok := headers[authHeader]; ok {
|
||||
authHeaders[authHeader] = v
|
||||
func reqHeaders(headers map[string]string) map[string]string {
|
||||
// copy to avoid changing the original map
|
||||
h := make(map[string]string, len(headers))
|
||||
for k, v := range headers {
|
||||
h[k] = v
|
||||
}
|
||||
|
||||
if v, ok := headers[idTokenHeader]; ok {
|
||||
authHeaders[idTokenHeader] = v
|
||||
}
|
||||
|
||||
return authHeaders
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestGetClient(t *testing.T) {
|
||||
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders")
|
||||
})
|
||||
|
||||
t.Run("oauth pass through", func(t *testing.T) {
|
||||
t.Run("extra headers", func(t *testing.T) {
|
||||
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) {
|
||||
tc := setup(`{"oauthPassThru":true}`)
|
||||
_, err := tc.promClientProvider.GetClient(headers)
|
||||
@@ -52,14 +52,14 @@ func TestGetClient(t *testing.T) {
|
||||
require.Equal(t, headers, tc.httpProvider.opts.Headers)
|
||||
})
|
||||
|
||||
t.Run("it only sets auth headers", func(t *testing.T) {
|
||||
t.Run("it sets all headers", func(t *testing.T) {
|
||||
withNonAuth := map[string]string{"X-Not-Auth": "stuff"}
|
||||
|
||||
tc := setup(`{"oauthPassThru":true}`)
|
||||
_, err := tc.promClientProvider.GetClient(withNonAuth)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, map[string]string{}, tc.httpProvider.opts.Headers)
|
||||
require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers)
|
||||
})
|
||||
|
||||
t.Run("it does not error when headers are nil", func(t *testing.T) {
|
||||
@@ -68,14 +68,6 @@ func TestGetClient(t *testing.T) {
|
||||
_, err := tc.promClientProvider.GetClient(nil)
|
||||
require.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("it does not set the headers when 'oauthPassThru' is false", func(t *testing.T) {
|
||||
tc := setup()
|
||||
_, err := tc.promClientProvider.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, tc.httpProvider.opts.Headers, 0)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("force get middleware", func(t *testing.T) {
|
||||
|
||||
@@ -64,7 +64,7 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
|
||||
}
|
||||
|
||||
p := promclient.NewProvider(settings, jsonData, httpClientProvider, plog)
|
||||
pc, err := promclient.NewProviderCache(p, jsonData)
|
||||
pc, err := promclient.NewProviderCache(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user