Propagate all headers (#43812)

* Propagate all headers

* stable header order
This commit is contained in:
Travis Patterson
2022-01-07 12:45:26 -07:00
committed by GitHub
parent 546818819b
commit 9eb82f9fff
7 changed files with 215 additions and 70 deletions

View File

@@ -3,6 +3,7 @@ package query
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@@ -22,9 +23,20 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
) )
func ProvideService(cfg *setting.Cfg, dataSourceCache datasources.CacheService, expressionService *expr.Service, const (
pluginRequestValidator models.PluginRequestValidator, SecretsService secrets.Service, headerName = "httpHeaderName"
pluginClient plugins.Client, OAuthTokenService oauthtoken.OAuthTokenService) *Service { headerValue = "httpHeaderValue"
)
func ProvideService(
cfg *setting.Cfg,
dataSourceCache datasources.CacheService,
expressionService *expr.Service,
pluginRequestValidator models.PluginRequestValidator,
SecretsService secrets.Service,
pluginClient plugins.Client,
oAuthTokenService oauthtoken.OAuthTokenService,
) *Service {
g := &Service{ g := &Service{
cfg: cfg, cfg: cfg,
dataSourceCache: dataSourceCache, dataSourceCache: dataSourceCache,
@@ -32,14 +44,13 @@ func ProvideService(cfg *setting.Cfg, dataSourceCache datasources.CacheService,
pluginRequestValidator: pluginRequestValidator, pluginRequestValidator: pluginRequestValidator,
secretsService: SecretsService, secretsService: SecretsService,
pluginClient: pluginClient, pluginClient: pluginClient,
oAuthTokenService: OAuthTokenService, oAuthTokenService: oAuthTokenService,
log: log.New("query_data"), log: log.New("query_data"),
} }
g.log.Info("Query Service initialization") g.log.Info("Query Service initialization")
return g return g
} }
// Gateway receives data and translates it to Grafana Live publications.
type Service struct { type Service struct {
cfg *setting.Cfg cfg *setting.Cfg
dataSourceCache datasources.CacheService dataSourceCache datasources.CacheService
@@ -135,6 +146,10 @@ func (s *Service) handleQueryData(ctx context.Context, user *models.SignedInUser
} }
} }
for k, v := range customHeaders(ds.JsonData, instanceSettings.DecryptedSecureJSONData) {
req.Headers[k] = v
}
for _, q := range parsedReq.parsedQueries { for _, q := range parsedReq.parsedQueries {
req.Queries = append(req.Queries, q.query) req.Queries = append(req.Queries, q.query)
} }
@@ -152,6 +167,26 @@ type parsedRequest struct {
parsedQueries []parsedQuery parsedQueries []parsedQuery
} }
func customHeaders(jsonData *simplejson.Json, decryptedJsonData map[string]string) map[string]string {
if jsonData == nil {
return nil
}
data := jsonData.MustMap()
headers := map[string]string{}
for k := range data {
if strings.HasPrefix(k, headerName) {
if header, ok := data[k].(string); ok {
valueKey := strings.ReplaceAll(k, headerName, headerValue)
headers[header] = decryptedJsonData[valueKey]
}
}
}
return headers
}
func (s *Service) parseMetricRequest(ctx context.Context, user *models.SignedInUser, skipCache bool, reqDTO dtos.MetricRequest) (*parsedRequest, error) { func (s *Service) parseMetricRequest(ctx context.Context, user *models.SignedInUser, skipCache bool, reqDTO dtos.MetricRequest) (*parsedRequest, error) {
if len(reqDTO.Queries) == 0 { if len(reqDTO.Queries) == 0 {
return nil, NewErrBadQuery("no queries found") return nil, NewErrBadQuery("no queries found")

View File

@@ -0,0 +1,142 @@
package query_test
import (
"context"
"net/http"
"testing"
"golang.org/x/oauth2"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/query"
"github.com/grafana/grafana/pkg/services/secrets"
"github.com/stretchr/testify/require"
)
func TestQueryData(t *testing.T) {
t.Run("it attaches custom headers to the request", func(t *testing.T) {
tc := setup()
tc.dataSourceCache.ds.JsonData = simplejson.NewFromAny(map[string]interface{}{"httpHeaderName1": "foo", "httpHeaderName2": "bar"})
tc.secretService.decryptedJson = map[string]string{"httpHeaderValue1": "test-header", "httpHeaderValue2": "test-header2"}
_, err := tc.queryService.QueryData(context.Background(), nil, true, metricRequest(), false)
require.Nil(t, err)
require.Equal(t, map[string]string{"foo": "test-header", "bar": "test-header2"}, tc.pluginContext.req.Headers)
})
t.Run("it auth custom headers to the request", func(t *testing.T) {
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc := setup()
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
_, err := tc.queryService.QueryData(context.Background(), nil, true, metricRequest(), false)
require.Nil(t, err)
expected := map[string]string{
"Authorization": "Bearer access-token",
"X-ID-Token": "id-token",
}
require.Equal(t, expected, tc.pluginContext.req.Headers)
})
}
func setup() *testContext {
pc := &fakePluginClient{}
sc := &fakeSecretsService{}
dc := &fakeDataSourceCache{ds: &models.DataSource{}}
tc := &fakeOAuthTokenService{}
rv := &fakePluginRequestValidator{}
return &testContext{
pluginContext: pc,
secretService: sc,
dataSourceCache: dc,
oauthTokenService: tc,
pluginRequestValidator: rv,
queryService: query.ProvideService(nil, dc, nil, rv, sc, pc, tc),
}
}
type testContext struct {
pluginContext *fakePluginClient
secretService *fakeSecretsService
dataSourceCache *fakeDataSourceCache
oauthTokenService *fakeOAuthTokenService
pluginRequestValidator *fakePluginRequestValidator
queryService *query.Service
}
func metricRequest() dtos.MetricRequest {
q, _ := simplejson.NewJson([]byte(`{"datasourceId":1}`))
return dtos.MetricRequest{
From: "",
To: "",
Queries: []*simplejson.Json{q},
Debug: false,
}
}
type fakePluginRequestValidator struct {
err error
}
func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request) error {
return rv.err
}
type fakeOAuthTokenService struct {
passThruEnabled bool
token *oauth2.Token
}
func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *models.SignedInUser) *oauth2.Token {
return ts.token
}
func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*models.DataSource) bool {
return ts.passThruEnabled
}
type fakeSecretsService struct {
secrets.Service
decryptedJson map[string]string
}
func (s *fakeSecretsService) DecryptJsonData(ctx context.Context, sjd map[string][]byte) (map[string]string, error) {
return s.decryptedJson, nil
}
type fakeDataSourceCache struct {
ds *models.DataSource
}
func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *models.SignedInUser, skipCache bool) (*models.DataSource, error) {
return c.ds, nil
}
func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *models.SignedInUser, skipCache bool) (*models.DataSource, error) {
return c.ds, nil
}
type fakePluginClient struct {
plugins.Client
req *backend.QueryDataRequest
}
func (c *fakePluginClient) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
c.req = req
return nil, nil
}

View File

@@ -1,25 +1,23 @@
package promclient package promclient
import ( import (
"sort"
"strings"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
) )
const (
noPassThrough = "no-pass-through"
)
type ProviderCache struct { type ProviderCache struct {
provider promClientProvider provider promClientProvider
cache *lru.Cache cache *lru.Cache
jsonData JsonData
} }
type promClientProvider interface { type promClientProvider interface {
GetClient(map[string]string) (apiv1.API, error) GetClient(map[string]string) (apiv1.API, error)
} }
func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error) { func NewProviderCache(p promClientProvider) (*ProviderCache, error) {
cache, err := lru.New(500) cache, err := lru.New(500)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -28,7 +26,6 @@ func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error)
return &ProviderCache{ return &ProviderCache{
provider: p, provider: p,
cache: cache, cache: cache,
jsonData: jd,
}, nil }, nil
} }
@@ -48,8 +45,12 @@ func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error)
} }
func (c *ProviderCache) key(headers map[string]string) string { func (c *ProviderCache) key(headers map[string]string) string {
if c.jsonData.OauthPassThru { vals := make([]string, len(headers))
return headers[authHeader] + headers[idTokenHeader] var i int
for _, v := range headers {
vals[i] = v
i++
} }
return noPassThrough sort.Strings(vals)
return strings.Join(vals, "")
} }

View File

@@ -16,7 +16,7 @@ import (
func TestCache_GetClient(t *testing.T) { func TestCache_GetClient(t *testing.T) {
t.Run("it caches the client for a set of auth headers", func(t *testing.T) { t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
tc := setupCacheContext(true) tc := setupCacheContext()
c, err := tc.providerCache.GetClient(headers) c, err := tc.providerCache.GetClient(headers)
require.Nil(t, err) require.Nil(t, err)
@@ -28,8 +28,8 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 1, tc.clientProvider.numCalls) require.Equal(t, 1, tc.clientProvider.numCalls)
}) })
t.Run("it returns different clients when the auth headers differ", func(t *testing.T) { t.Run("it returns different clients when the headers differ", func(t *testing.T) {
tc := setupCacheContext(true) tc := setupCacheContext()
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
@@ -43,10 +43,10 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 2, tc.clientProvider.numCalls) require.Equal(t, 2, tc.clientProvider.numCalls)
}) })
t.Run("it always returns from the cache when 'oauthPassThru' not set", func(t *testing.T) { t.Run("it returns from the cache when headers are the same", func(t *testing.T) {
tc := setupCacheContext(false) tc := setupCacheContext()
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
c, err := tc.providerCache.GetClient(h1) c, err := tc.providerCache.GetClient(h1)
require.Nil(t, err) require.Nil(t, err)
@@ -58,21 +58,8 @@ func TestCache_GetClient(t *testing.T) {
require.Equal(t, 1, tc.clientProvider.numCalls) require.Equal(t, 1, tc.clientProvider.numCalls)
}) })
t.Run("it only accounts for auth headers", func(t *testing.T) {
tc := setupCacheContext(true)
c, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "stuff"})
require.Nil(t, err)
c2, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "other-stuff"})
require.Nil(t, err)
require.Equal(t, c, c2)
require.Equal(t, 1, tc.clientProvider.numCalls)
})
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) { t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
tc := setupCacheContext(true) tc := setupCacheContext()
tc.clientProvider.errors <- errors.New("something bad") tc.clientProvider.errors <- errors.New("something bad")
_, err := tc.providerCache.GetClient(headers) _, err := tc.providerCache.GetClient(headers)
@@ -91,9 +78,9 @@ type cacheTestContext struct {
clientProvider *fakePromClientProvider clientProvider *fakePromClientProvider
} }
func setupCacheContext(oauthPassTrough bool) *cacheTestContext { func setupCacheContext() *cacheTestContext {
fp := newFakePromClientProvider() fp := newFakePromClientProvider()
p, err := promclient.NewProviderCache(fp, promclient.JsonData{OauthPassThru: oauthPassTrough}) p, err := promclient.NewProviderCache(fp)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -14,11 +14,6 @@ import (
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
) )
const (
authHeader = "Authorization"
idTokenHeader = "X-ID-Token"
)
type Provider struct { type Provider struct {
settings backend.DataSourceInstanceSettings settings backend.DataSourceInstanceSettings
jsonData JsonData jsonData JsonData
@@ -41,9 +36,8 @@ func NewProvider(
} }
type JsonData struct { type JsonData struct {
Method string `json:"httpMethod"` Method string `json:"httpMethod"`
OauthPassThru bool `json:"oauthPassThru"` TimeInterval string `json:"timeInterval"`
TimeInterval string `json:"timeInterval"`
} }
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) { func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
@@ -53,9 +47,7 @@ func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) {
} }
opts.Middlewares = p.middlewares() opts.Middlewares = p.middlewares()
if p.jsonData.OauthPassThru { opts.Headers = reqHeaders(headers)
opts.Headers = authHeaders(headers)
}
// Set SigV4 service namespace // Set SigV4 service namespace
if opts.SigV4 != nil { if opts.SigV4 != nil {
@@ -92,15 +84,11 @@ func (p *Provider) middlewares() []sdkhttpclient.Middleware {
return middlewares return middlewares
} }
func authHeaders(headers map[string]string) map[string]string { func reqHeaders(headers map[string]string) map[string]string {
authHeaders := make(map[string]string) // copy to avoid changing the original map
if v, ok := headers[authHeader]; ok { h := make(map[string]string, len(headers))
authHeaders[authHeader] = v for k, v := range headers {
h[k] = v
} }
return h
if v, ok := headers[idTokenHeader]; ok {
authHeaders[idTokenHeader] = v
}
return authHeaders
} }

View File

@@ -43,7 +43,7 @@ func TestGetClient(t *testing.T) {
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders") require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders")
}) })
t.Run("oauth pass through", func(t *testing.T) { t.Run("extra headers", func(t *testing.T) {
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) { t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) {
tc := setup(`{"oauthPassThru":true}`) tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(headers) _, err := tc.promClientProvider.GetClient(headers)
@@ -52,14 +52,14 @@ func TestGetClient(t *testing.T) {
require.Equal(t, headers, tc.httpProvider.opts.Headers) require.Equal(t, headers, tc.httpProvider.opts.Headers)
}) })
t.Run("it only sets auth headers", func(t *testing.T) { t.Run("it sets all headers", func(t *testing.T) {
withNonAuth := map[string]string{"X-Not-Auth": "stuff"} withNonAuth := map[string]string{"X-Not-Auth": "stuff"}
tc := setup(`{"oauthPassThru":true}`) tc := setup(`{"oauthPassThru":true}`)
_, err := tc.promClientProvider.GetClient(withNonAuth) _, err := tc.promClientProvider.GetClient(withNonAuth)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, map[string]string{}, tc.httpProvider.opts.Headers) require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers)
}) })
t.Run("it does not error when headers are nil", func(t *testing.T) { t.Run("it does not error when headers are nil", func(t *testing.T) {
@@ -68,14 +68,6 @@ func TestGetClient(t *testing.T) {
_, err := tc.promClientProvider.GetClient(nil) _, err := tc.promClientProvider.GetClient(nil)
require.Nil(t, err) require.Nil(t, err)
}) })
t.Run("it does not set the headers when 'oauthPassThru' is false", func(t *testing.T) {
tc := setup()
_, err := tc.promClientProvider.GetClient(headers)
require.Nil(t, err)
require.Len(t, tc.httpProvider.opts.Headers, 0)
})
}) })
t.Run("force get middleware", func(t *testing.T) { t.Run("force get middleware", func(t *testing.T) {

View File

@@ -64,7 +64,7 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
} }
p := promclient.NewProvider(settings, jsonData, httpClientProvider, plog) p := promclient.NewProvider(settings, jsonData, httpClientProvider, plog)
pc, err := promclient.NewProviderCache(p, jsonData) pc, err := promclient.NewProviderCache(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }