Propagate all headers (#43812)

* Propagate all headers

* stable header order
This commit is contained in:
Travis Patterson
2022-01-07 12:45:26 -07:00
committed by GitHub
parent 546818819b
commit 9eb82f9fff
7 changed files with 215 additions and 70 deletions

View File

@@ -1,25 +1,23 @@
package promclient
import (
"sort"
"strings"
lru "github.com/hashicorp/golang-lru"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
)
const (
noPassThrough = "no-pass-through"
)
type ProviderCache struct {
provider promClientProvider
cache *lru.Cache
jsonData JsonData
}
type promClientProvider interface {
GetClient(map[string]string) (apiv1.API, error)
}
func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error) {
func NewProviderCache(p promClientProvider) (*ProviderCache, error) {
cache, err := lru.New(500)
if err != nil {
return nil, err
@@ -28,7 +26,6 @@ func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error)
return &ProviderCache{
provider: p,
cache: cache,
jsonData: jd,
}, nil
}
@@ -48,8 +45,12 @@ func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error)
}
func (c *ProviderCache) key(headers map[string]string) string {
if c.jsonData.OauthPassThru {
return headers[authHeader] + headers[idTokenHeader]
vals := make([]string, len(headers))
var i int
for _, v := range headers {
vals[i] = v
i++
}
return noPassThrough
sort.Strings(vals)
return strings.Join(vals, "")
}

View File

@@ -16,7 +16,7 @@ import (
func TestCache_GetClient(t *testing.T) {
t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
tc := setupCacheContext(true)
tc := setupCacheContext()
c, err := tc.providerCache.GetClient(headers)
require.Nil(t, err)
@@ -28,8 +28,8 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it returns different clients when the auth headers differ", func(t *testing.T) {
tc := setupCacheContext(true)
t.Run("it returns different clients when the headers differ", func(t *testing.T) {
tc := setupCacheContext()
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
@@ -43,10 +43,10 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 2, tc.clientProvider.numCalls)
})
t.Run("it always returns from the cache when 'oauthPassThru' not set", func(t *testing.T) {
tc := setupCacheContext(false)
t.Run("it returns from the cache when headers are the same", func(t *testing.T) {
tc := setupCacheContext()
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
c, err := tc.providerCache.GetClient(h1)
require.Nil(t, err)
@@ -58,21 +58,8 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it only accounts for auth headers", func(t *testing.T) {
tc := setupCacheContext(true)
c, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "stuff"})
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "other-stuff"})
require.Nil(t, err)
require.Equal(t, c, c2)
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
tc := setupCacheContext(true)
tc := setupCacheContext()
tc.clientProvider.errors <- errors.New("something bad")
_, err := tc.providerCache.GetClient(headers)
@@ -91,9 +78,9 @@ type cacheTestContext struct {
clientProvider *fakePromClientProvider
}
func setupCacheContext(oauthPassTrough bool) *cacheTestContext {
func setupCacheContext() *cacheTestContext {
fp := newFakePromClientProvider()
p, err := promclient.NewProviderCache(fp, promclient.JsonData{OauthPassThru: oauthPassTrough})
p, err := promclient.NewProviderCache(fp)
if err != nil {
panic(err)
}

View File

@@ -14,11 +14,6 @@ import (
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
)
const (
authHeader = "Authorization"
idTokenHeader = "X-ID-Token"
)
type Provider struct {
settings backend.DataSourceInstanceSettings
jsonData JsonData
@@ -41,9 +36,8 @@ func NewProvider(
}
type JsonData struct {
Method string `json:"httpMethod"`
OauthPassThru bool `json:"oauthPassThru"`
TimeInterval string `json:"timeInterval"`
Method string `json:"httpMethod"`
TimeInterval string `json:"timeInterval"`
}
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
@@ -53,9 +47,7 @@ func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
}
opts.Middlewares = p.middlewares()
if p.jsonData.OauthPassThru {
opts.Headers = authHeaders(headers)
}
opts.Headers = reqHeaders(headers)
// Set SigV4 service namespace
if opts.SigV4 != nil {
@@ -92,15 +84,11 @@ func (p *Provider) middlewares() []sdkhttpclient.Middleware {
return middlewares
}
func authHeaders(headers map[string]string) map[string]string {
authHeaders := make(map[string]string)
if v, ok := headers[authHeader]; ok {
authHeaders[authHeader] = v
func reqHeaders(headers map[string]string) map[string]string {
// copy to avoid changing the original map
h := make(map[string]string, len(headers))
for k, v := range headers {
h[k] = v
}
if v, ok := headers[idTokenHeader]; ok {
authHeaders[idTokenHeader] = v
}
return authHeaders
return h
}

View File

@@ -43,7 +43,7 @@ func TestGetClient(t *testing.T) {
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders")
})
t.Run("oauth pass through", func(t *testing.T) {
t.Run("extra headers", func(t *testing.T) {
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) {
tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(headers)
@@ -52,14 +52,14 @@ func TestGetClient(t *testing.T) {
require.Equal(t, headers, tc.httpProvider.opts.Headers)
})
t.Run("it only sets auth headers", func(t *testing.T) {
t.Run("it sets all headers", func(t *testing.T) {
withNonAuth := map[string]string{"X-Not-Auth": "stuff"}
tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(withNonAuth)
require.Nil(t, err)
require.Equal(t, map[string]string{}, tc.httpProvider.opts.Headers)
require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers)
})
t.Run("it does not error when headers are nil", func(t *testing.T) {
@@ -68,14 +68,6 @@ func TestGetClient(t *testing.T) {
_, err := tc.promClientProvider.GetClient(nil)
require.Nil(t, err)
})
t.Run("it does not set the headers when 'oauthPassThru' is false", func(t *testing.T) {
tc := setup()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.opts.Headers, 0)
})
})
t.Run("force get middleware", func(t *testing.T) {

View File

@@ -64,7 +64,7 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
}
p := promclient.NewProvider(settings, jsonData, httpClientProvider, plog)
pc, err := promclient.NewProviderCache(p, jsonData)
pc, err := promclient.NewProviderCache(p)
if err != nil {
return nil, err
}