From c35c689a96549dbd6ed25492c9c482840edc86c1 Mon Sep 17 00:00:00 2001 From: Marcus Efraimsson Date: Wed, 21 Dec 2022 13:25:58 +0100 Subject: [PATCH] Plugins: Automatically forward plugin request HTTP headers in outgoing HTTP requests (#60417) Automatically forward core plugin request HTTP headers in outgoing HTTP requests. Core datasource plugin authors don't have to specifically handle forwarding of HTTP headers, e.g. do not have to "hardcode" the header-names in the datasource plugin, if not having custom needs. Fixes #57065 --- .../delete_headers_middleware.go | 27 - .../delete_headers_middleware_test.go | 66 --- .../set_headers_middleware.go | 31 -- .../set_headers_middleware_test.go | 66 --- pkg/plugins/manager/client/client.go | 1 + pkg/services/alerting/conditions/query.go | 5 +- pkg/services/ngalert/eval/eval.go | 4 +- pkg/services/ngalert/models/constants.go | 21 + .../clear_auth_headers_middleware.go | 20 +- .../clear_auth_headers_middleware_test.go | 112 +--- .../clientmiddleware/cookies_middleware.go | 49 +- .../cookies_middleware_test.go | 77 +-- .../clientmiddleware/httpclient_middleware.go | 108 ++++ .../httpclient_middleware_test.go | 289 ++++++++++ .../clientmiddleware/oauthtoken_middleware.go | 36 +- .../oauthtoken_middleware_test.go | 50 -- .../user_header_middleware.go | 32 +- .../user_header_middleware_test.go | 50 -- .../pluginsintegration/pluginsintegration.go | 2 + .../backendplugin/backendplugin_test.go | 508 ++++++++++-------- pkg/tsdb/cloudwatch/cloudwatch.go | 3 +- pkg/tsdb/cloudwatch/cloudwatch_test.go | 5 +- pkg/tsdb/loki/api.go | 29 +- pkg/tsdb/loki/api_mock.go | 4 +- pkg/tsdb/loki/auth_test.go | 201 ------- pkg/tsdb/loki/loki.go | 45 +- pkg/tsdb/loki/loki_test.go | 74 --- pkg/tsdb/prometheus/client/client.go | 27 +- pkg/tsdb/prometheus/client/client_test.go | 6 +- pkg/tsdb/prometheus/prometheus_test.go | 7 +- pkg/tsdb/prometheus/querydata/request.go | 33 +- pkg/tsdb/prometheus/querydata/request_test.go | 22 - 32 files changed, 816 insertions(+), 1194 deletions(-) delete mode 100644 pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go delete mode 100644 pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go delete mode 100644 pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go delete mode 100644 pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go create mode 100644 pkg/services/ngalert/models/constants.go create mode 100644 pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware.go create mode 100644 pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware_test.go delete mode 100644 pkg/tsdb/loki/auth_test.go delete mode 100644 pkg/tsdb/loki/loki_test.go diff --git a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go deleted file mode 100644 index 3bc3bcdc824..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go +++ /dev/null @@ -1,27 +0,0 @@ -package httpclientprovider - -import ( - "net/http" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" -) - -const DeleteHeadersMiddlewareName = "delete-headers" - -// DeleteHeadersMiddleware middleware that delete headers on the outgoing -// request if header names provided. -func DeleteHeadersMiddleware(headerNames ...string) httpclient.Middleware { - return httpclient.NamedMiddlewareFunc(DeleteHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { - if len(headerNames) == 0 { - return next - } - - return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - for _, k := range headerNames { - req.Header.Del(k) - } - - return next.RoundTrip(req) - }) - }) -} diff --git a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go deleted file mode 100644 index 01ec6a0b08e..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package httpclientprovider - -import ( - "net/http" - "testing" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/stretchr/testify/require" -) - -func TestDeleteHeadersMiddleware(t *testing.T) { - t.Run("Without headerNames should return next http.RoundTripper", func(t *testing.T) { - ctx := &testContext{} - finalRoundTripper := ctx.createRoundTripper("finalrt") - var headerNames []string - mw := DeleteHeadersMiddleware(headerNames...) - rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) - require.NotNil(t, rt) - middlewareName, ok := mw.(httpclient.MiddlewareName) - require.True(t, ok) - require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName()) - - req, err := http.NewRequest(http.MethodGet, "http://", nil) - require.NoError(t, err) - res, err := rt.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, res) - if res.Body != nil { - require.NoError(t, res.Body.Close()) - } - require.Len(t, ctx.callChain, 1) - require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain) - }) - - t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) { - ctx := &testContext{} - finalRoundTripper := ctx.createRoundTripper("final") - headerNames := []string{"X-Header-B", "X-Header-C"} - mw := DeleteHeadersMiddleware(headerNames...) - rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) - require.NotNil(t, rt) - middlewareName, ok := mw.(httpclient.MiddlewareName) - require.True(t, ok) - require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName()) - - req, err := http.NewRequest(http.MethodGet, "http://", nil) - require.NoError(t, err) - req.Header.Set("X-Header-A", "a") - req.Header.Set("X-Header-B", "b") - req.Header.Set("X-Header-C", "c") - req.Header.Set("X-Header-D", "d") - res, err := rt.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, res) - if res.Body != nil { - require.NoError(t, res.Body.Close()) - } - require.Len(t, ctx.callChain, 1) - require.ElementsMatch(t, []string{"final"}, ctx.callChain) - - require.Equal(t, "a", req.Header.Get("X-Header-A")) - require.Empty(t, req.Header.Get("X-Header-B")) - require.Empty(t, req.Header.Get("X-Header-C")) - require.Equal(t, "d", req.Header.Get("X-Header-D")) - }) -} diff --git a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go deleted file mode 100644 index 787cada75af..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go +++ /dev/null @@ -1,31 +0,0 @@ -package httpclientprovider - -import ( - "net/http" - "net/textproto" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" -) - -const SetHeadersMiddlewareName = "set-headers" - -// SetHeadersMiddleware middleware that sets headers on the outgoing -// request if headers provided. -// If the request already contains any of the headers provided, they -// will be overwritten. -func SetHeadersMiddleware(headers http.Header) httpclient.Middleware { - return httpclient.NamedMiddlewareFunc(SetHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { - if len(headers) == 0 { - return next - } - - return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - for k, v := range headers { - canonicalKey := textproto.CanonicalMIMEHeaderKey(k) - req.Header[canonicalKey] = v - } - - return next.RoundTrip(req) - }) - }) -} diff --git a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go deleted file mode 100644 index 0970d23cbcc..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package httpclientprovider - -import ( - "net/http" - "testing" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/stretchr/testify/require" -) - -func TestSetHeadersMiddleware(t *testing.T) { - t.Run("Without headers set should return next http.RoundTripper", func(t *testing.T) { - ctx := &testContext{} - finalRoundTripper := ctx.createRoundTripper("finalrt") - var headers http.Header - mw := SetHeadersMiddleware(headers) - rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) - require.NotNil(t, rt) - middlewareName, ok := mw.(httpclient.MiddlewareName) - require.True(t, ok) - require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName()) - - req, err := http.NewRequest(http.MethodGet, "http://", nil) - require.NoError(t, err) - res, err := rt.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, res) - if res.Body != nil { - require.NoError(t, res.Body.Close()) - } - require.Len(t, ctx.callChain, 1) - require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain) - }) - - t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) { - ctx := &testContext{} - finalRoundTripper := ctx.createRoundTripper("final") - headers := http.Header{ - "X-Header-A": []string{"value a"}, - "X-Header-B": []string{"value b"}, - "X-HEader-C": []string{"value c"}, - } - mw := SetHeadersMiddleware(headers) - rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) - require.NotNil(t, rt) - middlewareName, ok := mw.(httpclient.MiddlewareName) - require.True(t, ok) - require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName()) - - req, err := http.NewRequest(http.MethodGet, "http://", nil) - require.NoError(t, err) - req.Header.Set("X-Header-B", "d") - res, err := rt.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, res) - if res.Body != nil { - require.NoError(t, res.Body.Close()) - } - require.Len(t, ctx.callChain, 1) - require.ElementsMatch(t, []string{"final"}, ctx.callChain) - - require.Equal(t, "value a", req.Header.Get("X-Header-A")) - require.Equal(t, "value b", req.Header.Get("X-Header-B")) - require.Equal(t, "value c", req.Header.Get("X-Header-C")) - }) -} diff --git a/pkg/plugins/manager/client/client.go b/pkg/plugins/manager/client/client.go index cc6185ea3b9..b9262ac41d7 100644 --- a/pkg/plugins/manager/client/client.go +++ b/pkg/plugins/manager/client/client.go @@ -255,6 +255,7 @@ var hopHeaders = []string{ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 "Transfer-Encoding", "Upgrade", + "User-Agent", } // removeHopByHopHeaders removes hop-by-hop headers. Especially diff --git a/pkg/services/alerting/conditions/query.go b/pkg/services/alerting/conditions/query.go index dabb46e5eb8..f14416c3271 100644 --- a/pkg/services/alerting/conditions/query.go +++ b/pkg/services/alerting/conditions/query.go @@ -18,6 +18,7 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/alerting" "github.com/grafana/grafana/pkg/services/datasources" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" ) func init() { @@ -276,8 +277,8 @@ func (c *QueryCondition) getRequestForAlertRule(datasource *datasources.DataSour }, }, Headers: map[string]string{ - "FromAlert": "true", - "X-Cache-Skip": "true", + ngalertmodels.FromAlertHeaderName: "true", + ngalertmodels.CacheSkipHeaderName: "true", }, Debug: debug, } diff --git a/pkg/services/ngalert/eval/eval.go b/pkg/services/ngalert/eval/eval.go index 726d47f6224..3f2f982b0ea 100644 --- a/pkg/services/ngalert/eval/eval.go +++ b/pkg/services/ngalert/eval/eval.go @@ -223,9 +223,9 @@ func buildDatasourceHeaders(ctx EvaluationContext) map[string]string { // Note: The spelling of this headers is intentionally degenerate from the others for compatibility reasons. // When sent over a network, the key of this header is canonicalized to "Fromalert". // However, some datasources still compare against the string "FromAlert". - "FromAlert": "true", + models.FromAlertHeaderName: "true", - "X-Cache-Skip": "true", + models.CacheSkipHeaderName: "true", } key, ok := models.RuleKeyFromContext(ctx.Ctx) diff --git a/pkg/services/ngalert/models/constants.go b/pkg/services/ngalert/models/constants.go new file mode 100644 index 00000000000..f6b52d6150c --- /dev/null +++ b/pkg/services/ngalert/models/constants.go @@ -0,0 +1,21 @@ +package models + +const ( + // FromAlertHeaderName name of header added to datasource query requests + // to denote request is originating from Grafana Alerting. + // + // Data sources might check this in query method as sometimes alerting + // needs special considerations. + // Several existing systems also compare against the value of this header. + // Altering this constitutes a breaking change. + // + // Note: The spelling of this headers is intentionally degenerate from the + // others for compatibility reasons. When sent over a network, the key of + // this header is canonicalized to "Fromalert". + // However, some datasources still compare against the string "FromAlert". + FromAlertHeaderName = "FromAlert" + + // CacheSkipHeaderName name of header added to datasource query requests + // to denote request should not be cached. + CacheSkipHeaderName = "X-Cache-Skip" +) diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go index 10a7c98be84..cf2d37db6f0 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go @@ -4,8 +4,6 @@ import ( "context" "github.com/grafana/grafana-plugin-sdk-go/backend" - sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/services/contexthandler" ) @@ -25,19 +23,19 @@ type ClearAuthHeadersMiddleware struct { next plugins.Client } -func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, pCtx backend.PluginContext, req interface{}) context.Context { +func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, h backend.ForwardHTTPHeaders) { reqCtx := contexthandler.FromContext(ctx) // if no HTTP request context skip middleware - if req == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil { - return ctx + if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil { + return } list := contexthandler.AuthHTTPHeaderListFromContext(ctx) if list != nil { - ctx = sdkhttpclient.WithContextualMiddleware(ctx, httpclientprovider.DeleteHeadersMiddleware(list.Items...)) + for _, k := range list.Items { + h.DeleteHTTPHeader(k) + } } - - return ctx } func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { @@ -45,7 +43,7 @@ func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend return m.next.QueryData(ctx, req) } - ctx = m.clearHeaders(ctx, req.PluginContext, req) + m.clearHeaders(ctx, req) return m.next.QueryData(ctx, req) } @@ -55,7 +53,7 @@ func (m *ClearAuthHeadersMiddleware) CallResource(ctx context.Context, req *back return m.next.CallResource(ctx, req, sender) } - ctx = m.clearHeaders(ctx, req.PluginContext, req) + m.clearHeaders(ctx, req) return m.next.CallResource(ctx, req, sender) } @@ -65,7 +63,7 @@ func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backe return m.next.CheckHealth(ctx, req) } - ctx = m.clearHeaders(ctx, req.PluginContext, req) + m.clearHeaders(ctx, req) return m.next.CheckHealth(ctx, req) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go index 195b3bd75dd..6a4975203ca 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go @@ -1,14 +1,10 @@ package clientmiddleware import ( - "bytes" - "io" "net/http" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/user" @@ -42,9 +38,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { @@ -55,9 +48,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { @@ -68,9 +58,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 0) }) }) @@ -92,9 +79,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { @@ -105,9 +89,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { @@ -118,9 +99,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 0) }) }) }) @@ -155,18 +133,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) }) t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { @@ -177,18 +144,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader]) }) t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { @@ -199,18 +155,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) }) }) @@ -220,12 +165,12 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), ) - const customHeader = "X-Custom" + const customHeader = "x-Custom" req.Header.Set(customHeader, "val") ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) req = req.WithContext(ctx) - const otherHeader = "X-Other" + const otherHeader = "x-Other" req.Header.Set(otherHeader, "test") pluginCtx := backend.PluginContext{ @@ -240,18 +185,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) }) t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { @@ -262,18 +196,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader]) }) t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { @@ -284,27 +207,8 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Empty(t, reqClone.Header[customHeader]) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) }) }) }) } - -var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Request: req, - Body: io.NopCloser(bytes.NewBufferString("")), - }, nil -}) diff --git a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go index 1c0824cc084..a6c8172779d 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go @@ -4,9 +4,7 @@ import ( "context" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana/pkg/components/simplejson" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/datasources" @@ -16,8 +14,8 @@ import ( const cookieHeaderName = "Cookie" // NewCookiesMiddleware creates a new plugins.ClientMiddleware that will -// forward incoming HTTP request Cookies to outgoing plugins.Client and -// HTTP requests if the datasource has enabled forwarding of cookies (keepCookies). +// forward incoming HTTP request Cookies to outgoing plugins.Client requests +// if the datasource has enabled forwarding of cookies (keepCookies). func NewCookiesMiddleware(skipCookiesNames []string) plugins.ClientMiddleware { return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { return &CookiesMiddleware{ @@ -32,17 +30,17 @@ type CookiesMiddleware struct { skipCookiesNames []string } -func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) { +func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) error { reqCtx := contexthandler.FromContext(ctx) // if request not for a datasource or no HTTP request context skip middleware if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil { - return ctx, nil + return nil } settings := pCtx.DataSourceInstanceSettings jsonDataBytes, err := simplejson.NewJson(settings.JSONData) if err != nil { - return ctx, err + return err } ds := &datasources.DataSource{ @@ -54,20 +52,29 @@ func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.Plugi proxyutil.ClearCookieHeader(reqCtx.Req, ds.AllowedCookies(), m.skipCookiesNames) - if cookieStr := reqCtx.Req.Header.Get(cookieHeaderName); cookieStr != "" { - switch t := req.(type) { - case *backend.QueryDataRequest: + cookieStr := reqCtx.Req.Header.Get(cookieHeaderName) + switch t := req.(type) { + case *backend.QueryDataRequest: + if cookieStr == "" { + delete(t.Headers, cookieHeaderName) + } else { t.Headers[cookieHeaderName] = cookieStr - case *backend.CheckHealthRequest: + } + case *backend.CheckHealthRequest: + if cookieStr == "" { + delete(t.Headers, cookieHeaderName) + } else { t.Headers[cookieHeaderName] = cookieStr - case *backend.CallResourceRequest: + } + case *backend.CallResourceRequest: + if cookieStr == "" { + delete(t.Headers, cookieHeaderName) + } else { t.Headers[cookieHeaderName] = []string{cookieStr} } } - ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(reqCtx.Req.Cookies(), ds.AllowedCookies(), m.skipCookiesNames)) - - return ctx, nil + return nil } func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { @@ -75,12 +82,12 @@ func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDat return m.next.QueryData(ctx, req) } - newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + err := m.applyCookies(ctx, req.PluginContext, req) if err != nil { return nil, err } - return m.next.QueryData(newCtx, req) + return m.next.QueryData(ctx, req) } func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { @@ -88,12 +95,12 @@ func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallR return m.next.CallResource(ctx, req, sender) } - newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + err := m.applyCookies(ctx, req.PluginContext, req) if err != nil { return err } - return m.next.CallResource(newCtx, req, sender) + return m.next.CallResource(ctx, req, sender) } func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { @@ -101,12 +108,12 @@ func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckH return m.next.CheckHealth(ctx, req) } - newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + err := m.applyCookies(ctx, req.PluginContext, req) if err != nil { return nil, err } - return m.next.CheckHealth(newCtx, req) + return m.next.CheckHealth(ctx, req) } func (m *CookiesMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { diff --git a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go index a1493ea9742..473de641618 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go @@ -6,8 +6,6 @@ import ( "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" "github.com/grafana/grafana/pkg/services/user" "github.com/stretchr/testify/require" @@ -54,39 +52,19 @@ func TestCookiesMiddleware(t *testing.T) { require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) }) t.Run("Should not forward cookies when calling CallResource", func(t *testing.T) { - err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + pReq := &backend.CallResourceRequest{ PluginContext: pluginCtx, Headers: map[string][]string{otherHeader: {"test"}}, - }, nopCallResourceSender) + } + pReq.Headers[backend.CookiesHeaderName] = []string{req.Header.Get(backend.CookiesHeaderName)} + err = cdt.Decorator.CallResource(req.Context(), pReq, nopCallResourceSender) require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) }) t.Run("Should not forward cookies when calling CheckHealth", func(t *testing.T) { @@ -98,17 +76,6 @@ func TestCookiesMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 1) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) }) }) @@ -157,18 +124,6 @@ func TestCookiesMiddleware(t *testing.T) { require.Len(t, cdt.QueryDataReq.Headers, 2) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) require.EqualValues(t, "cookie2=", cdt.QueryDataReq.Headers[cookieHeaderName]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 2) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) }) t.Run("Should forward cookies when calling CallResource", func(t *testing.T) { @@ -182,18 +137,6 @@ func TestCookiesMiddleware(t *testing.T) { require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) require.Len(t, cdt.CallResourceReq.Headers[cookieHeaderName], 1) require.EqualValues(t, "cookie2=", cdt.CallResourceReq.Headers[cookieHeaderName][0]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 2) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) }) t.Run("Should forward cookies when calling CheckHealth", func(t *testing.T) { @@ -206,18 +149,6 @@ func TestCookiesMiddleware(t *testing.T) { require.Len(t, cdt.CheckHealthReq.Headers, 2) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) require.EqualValues(t, "cookie2=", cdt.CheckHealthReq.Headers[cookieHeaderName]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 2) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) }) }) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware.go new file mode 100644 index 00000000000..3050395c050 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware.go @@ -0,0 +1,108 @@ +package clientmiddleware + +import ( + "context" + "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/plugins" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" +) + +const forwardPluginRequestHTTPHeaders = "forward-plugin-request-http-headers" + +// NewHTTPClientMiddleware creates a new plugins.ClientMiddleware +// that will forward plugin request headers as outgoing HTTP headers. +func NewHTTPClientMiddleware() plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &HTTPClientMiddleware{ + next: next, + } + }) +} + +type HTTPClientMiddleware struct { + next plugins.Client +} + +func (m *HTTPClientMiddleware) applyHeaders(ctx context.Context, pReq interface{}) context.Context { + if pReq == nil { + return ctx + } + + mw := httpclient.NamedMiddlewareFunc(forwardPluginRequestHTTPHeaders, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { + return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + switch t := pReq.(type) { + case *backend.QueryDataRequest: + if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists { + req.Header.Set(ngalertmodels.FromAlertHeaderName, val) + } + case *backend.CallResourceRequest: + if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists { + req.Header.Set(ngalertmodels.FromAlertHeaderName, val[0]) + } + case *backend.CheckHealthRequest: + if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists { + req.Header.Set(ngalertmodels.FromAlertHeaderName, val) + } + } + + if h, ok := pReq.(backend.ForwardHTTPHeaders); ok { + for k, v := range h.GetHTTPHeaders() { + req.Header[k] = v + } + } + + return next.RoundTrip(req) + }) + }) + + return httpclient.WithContextualMiddleware(ctx, mw) +} + +func (m *HTTPClientMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return m.next.QueryData(ctx, req) + } + + ctx = m.applyHeaders(ctx, req) + + return m.next.QueryData(ctx, req) +} + +func (m *HTTPClientMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return m.next.CallResource(ctx, req, sender) + } + + ctx = m.applyHeaders(ctx, req) + + return m.next.CallResource(ctx, req, sender) +} + +func (m *HTTPClientMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return m.next.CheckHealth(ctx, req) + } + + ctx = m.applyHeaders(ctx, req) + + return m.next.CheckHealth(ctx, req) +} + +func (m *HTTPClientMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + return m.next.CollectMetrics(ctx, req) +} + +func (m *HTTPClientMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + return m.next.SubscribeStream(ctx, req) +} + +func (m *HTTPClientMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + return m.next.PublishStream(ctx, req) +} + +func (m *HTTPClientMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + return m.next.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware_test.go new file mode 100644 index 00000000000..691a96d1b4e --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/httpclient_middleware_test.go @@ -0,0 +1,289 @@ +package clientmiddleware + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/util/proxyutil" + "github.com/stretchr/testify/require" +) + +func TestHTTPClientMiddleware(t *testing.T) { + const otherHeader = "test" + + t.Run("When no http headers in plugin request", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + t.Run("And requests are for a datasource", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewHTTPClientMiddleware()), + ) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } + + t.Run("Should not forward headers when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "val"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + + t.Run("Should not forward headers when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"val"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + + t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "val"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + }) + + t.Run("And requests are for an app", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewHTTPClientMiddleware()), + ) + + pluginCtx := backend.PluginContext{ + AppInstanceSettings: &backend.AppInstanceSettings{}, + } + + t.Run("Should not forward headers when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 0) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + + t.Run("Should not forward headers when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 0) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + + t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 0) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 0) + }) + }) + }) + + t.Run("When HTTP headers in plugin request", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + headers := map[string]string{ + ngalertmodels.FromAlertHeaderName: "true", + backend.OAuthIdentityTokenHeaderName: "bearer token", + backend.OAuthIdentityIDTokenHeaderName: "id-token", + "http_" + proxyutil.UserHeaderName: "uname", + backend.CookiesHeaderName: "cookie1=; cookie2=; cookie3=", + otherHeader: "val", + } + + crHeaders := map[string][]string{} + for k, v := range headers { + crHeaders[k] = []string{v} + } + + t.Run("And requests are for a datasource", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewHTTPClientMiddleware()), + ) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } + + t.Run("Should forward headers when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: headers, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 6) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 5) + require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName)) + require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName)) + require.Len(t, reqClone.Cookies(), 3) + require.Equal(t, "cookie1", reqClone.Cookies()[0].Name) + require.Equal(t, "cookie2", reqClone.Cookies()[1].Name) + require.Equal(t, "cookie3", reqClone.Cookies()[2].Name) + }) + + t.Run("Should forward headers when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: crHeaders, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 6) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 5) + require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName)) + require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName)) + require.Len(t, reqClone.Cookies(), 3) + require.Equal(t, "cookie1", reqClone.Cookies()[0].Name) + require.Equal(t, "cookie2", reqClone.Cookies()[1].Name) + require.Equal(t, "cookie3", reqClone.Cookies()[2].Name) + }) + + t.Run("Should forward headers when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: headers, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 6) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 5) + require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName)) + require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName)) + require.Len(t, reqClone.Cookies(), 3) + require.Equal(t, "cookie1", reqClone.Cookies()[0].Name) + require.Equal(t, "cookie2", reqClone.Cookies()[1].Name) + require.Equal(t, "cookie3", reqClone.Cookies()[2].Name) + }) + }) + }) +} + +var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Request: req, + Body: io.NopCloser(bytes.NewBufferString("")), + }, nil +}) diff --git a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go index bf8cbaac4f8..69d19a118fa 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go @@ -3,12 +3,9 @@ package clientmiddleware import ( "context" "fmt" - "net/http" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana/pkg/components/simplejson" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/datasources" @@ -16,8 +13,8 @@ import ( ) // NewOAuthTokenMiddleware creates a new plugins.ClientMiddleware that will -// set OAuth token headers on outgoing plugins.Client and HTTP requests if -// the datasource has enabled Forward OAuth Identity (oauthPassThru). +// set OAuth token headers on outgoing plugins.Client requests if the +// datasource has enabled Forward OAuth Identity (oauthPassThru). func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plugins.ClientMiddleware { return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { return &OAuthTokenMiddleware{ @@ -37,17 +34,17 @@ type OAuthTokenMiddleware struct { next plugins.Client } -func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) { +func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) error { reqCtx := contexthandler.FromContext(ctx) // if request not for a datasource or no HTTP request context skip middleware if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil { - return ctx, nil + return nil } settings := pCtx.DataSourceInstanceSettings jsonDataBytes, err := simplejson.NewJson(settings.JSONData) if err != nil { - return ctx, err + return err } ds := &datasources.DataSource{ @@ -84,19 +81,10 @@ func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.Plug t.Headers[idTokenHeaderName] = []string{idTokenHeader} } } - - httpHeaders := http.Header{} - httpHeaders.Set(tokenHeaderName, authorizationHeader) - - if idTokenHeader != "" { - httpHeaders.Set(idTokenHeaderName, idTokenHeader) - } - - ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.SetHeadersMiddleware(httpHeaders)) } } - return ctx, nil + return nil } func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { @@ -104,12 +92,12 @@ func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.Query return m.next.QueryData(ctx, req) } - newCtx, err := m.applyToken(ctx, req.PluginContext, req) + err := m.applyToken(ctx, req.PluginContext, req) if err != nil { return nil, err } - return m.next.QueryData(newCtx, req) + return m.next.QueryData(ctx, req) } func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { @@ -117,12 +105,12 @@ func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.Ca return m.next.CallResource(ctx, req, sender) } - newCtx, err := m.applyToken(ctx, req.PluginContext, req) + err := m.applyToken(ctx, req.PluginContext, req) if err != nil { return err } - return m.next.CallResource(newCtx, req, sender) + return m.next.CallResource(ctx, req, sender) } func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { @@ -130,12 +118,12 @@ func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.Che return m.next.CheckHealth(ctx, req) } - newCtx, err := m.applyToken(ctx, req.PluginContext, req) + err := m.applyToken(ctx, req.PluginContext, req) if err != nil { return nil, err } - return m.next.CheckHealth(newCtx, req) + return m.next.CheckHealth(ctx, req) } func (m *OAuthTokenMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { diff --git a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go index 6a22a8c8921..ad79e5fd8db 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go @@ -6,8 +6,6 @@ import ( "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/user" @@ -49,9 +47,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not forward OAuth Identity when calling CallResource", func(t *testing.T) { @@ -63,9 +58,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 0) }) t.Run("Should not forward OAuth Identity when calling CheckHealth", func(t *testing.T) { @@ -77,9 +69,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 0) }) }) @@ -125,19 +114,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName]) require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 3) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) - require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) }) t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) { @@ -153,19 +129,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0]) require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1) require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 3) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) - require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) }) t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) { @@ -179,19 +142,6 @@ func TestOAuthTokenMiddleware(t *testing.T) { require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName]) require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName]) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - - reqClone := req.Clone(req.Context()) - res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - require.Len(t, reqClone.Header, 3) - require.Equal(t, "test", reqClone.Header.Get(otherHeader)) - require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) - require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) }) }) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go index 65cd4266cc4..ddce870d0b2 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go @@ -2,19 +2,15 @@ package clientmiddleware import ( "context" - "net/http" "github.com/grafana/grafana-plugin-sdk-go/backend" - sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/util/proxyutil" ) // NewUserHeaderMiddleware creates a new plugins.ClientMiddleware that will -// populate the X-Grafana-User header on outgoing plugins.Client and HTTP -// requests. +// populate the X-Grafana-User header on outgoing plugins.Client requests. func NewUserHeaderMiddleware() plugins.ClientMiddleware { return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { return &UserHeaderMiddleware{ @@ -27,33 +23,17 @@ type UserHeaderMiddleware struct { next plugins.Client } -func (m *UserHeaderMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, h backend.ForwardHTTPHeaders) context.Context { +func (m *UserHeaderMiddleware) applyUserHeader(ctx context.Context, h backend.ForwardHTTPHeaders) { reqCtx := contexthandler.FromContext(ctx) // if no HTTP request context skip middleware if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil { - return ctx + return } h.DeleteHTTPHeader(proxyutil.UserHeaderName) if !reqCtx.IsAnonymous { h.SetHTTPHeader(proxyutil.UserHeaderName, reqCtx.Login) } - - middlewares := []sdkhttpclient.Middleware{} - - if !reqCtx.IsAnonymous { - httpHeaders := http.Header{ - proxyutil.UserHeaderName: []string{reqCtx.Login}, - } - - middlewares = append(middlewares, httpclientprovider.SetHeadersMiddleware(httpHeaders)) - } else { - middlewares = append(middlewares, httpclientprovider.DeleteHeadersMiddleware(proxyutil.UserHeaderName)) - } - - ctx = sdkhttpclient.WithContextualMiddleware(ctx, middlewares...) - - return ctx } func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { @@ -61,7 +41,7 @@ func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.Query return m.next.QueryData(ctx, req) } - ctx = m.applyToken(ctx, req.PluginContext, req) + m.applyUserHeader(ctx, req) return m.next.QueryData(ctx, req) } @@ -71,7 +51,7 @@ func (m *UserHeaderMiddleware) CallResource(ctx context.Context, req *backend.Ca return m.next.CallResource(ctx, req, sender) } - ctx = m.applyToken(ctx, req.PluginContext, req) + m.applyUserHeader(ctx, req) return m.next.CallResource(ctx, req, sender) } @@ -81,7 +61,7 @@ func (m *UserHeaderMiddleware) CheckHealth(ctx context.Context, req *backend.Che return m.next.CheckHealth(ctx, req) } - ctx = m.applyToken(ctx, req.PluginContext, req) + m.applyUserHeader(ctx, req) return m.next.CheckHealth(ctx, req) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go index 8b49e1c0f7a..9fedd33a0e2 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util/proxyutil" @@ -39,10 +37,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Empty(t, cdt.QueryDataReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should not forward user header when calling CallResource", func(t *testing.T) { @@ -53,10 +47,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Empty(t, cdt.CallResourceReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) { @@ -67,10 +57,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Empty(t, cdt.CheckHealthReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) }) @@ -95,10 +81,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Empty(t, cdt.QueryDataReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should not forward user header when calling CallResource", func(t *testing.T) { @@ -109,10 +91,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Empty(t, cdt.CallResourceReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) { @@ -123,10 +101,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Empty(t, cdt.CheckHealthReq.Headers) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) }) }) @@ -156,10 +130,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should forward user header when calling CallResource", func(t *testing.T) { @@ -171,10 +141,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) { @@ -186,10 +152,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) }) @@ -214,10 +176,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should forward user header when calling CallResource", func(t *testing.T) { @@ -229,10 +187,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) { @@ -244,10 +198,6 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName)) - - middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) - require.Len(t, middlewares, 1) - require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) }) }) }) diff --git a/pkg/services/pluginsintegration/pluginsintegration.go b/pkg/services/pluginsintegration/pluginsintegration.go index efe42665e81..daf48380260 100644 --- a/pkg/services/pluginsintegration/pluginsintegration.go +++ b/pkg/services/pluginsintegration/pluginsintegration.go @@ -81,5 +81,7 @@ func CreateMiddlewares(cfg *setting.Cfg, oAuthTokenService oauthtoken.OAuthToken middlewares = append(middlewares, clientmiddleware.NewUserHeaderMiddleware()) } + middlewares = append(middlewares, clientmiddleware.NewHTTPClientMiddleware()) + return middlewares } diff --git a/pkg/tests/api/plugins/backendplugin/backendplugin_test.go b/pkg/tests/api/plugins/backendplugin/backendplugin_test.go index c7169aa07b9..d0d97c8cf99 100644 --- a/pkg/tests/api/plugins/backendplugin/backendplugin_test.go +++ b/pkg/tests/api/plugins/backendplugin/backendplugin_test.go @@ -26,92 +26,226 @@ import ( "golang.org/x/oauth2" ) +const loginCookieName = "grafana_session" + func TestIntegrationBackendPlugins(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } - regularQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { - t.Helper() - - return metricRequestWithQueries(t, fmt.Sprintf(`{ - "datasource": { - "uid": "%s" - } - }`, tsCtx.uid)) + oauthToken := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + RefreshToken: "refresh-token", + Expiry: time.Now().UTC().Add(24 * time.Hour), } + oauthToken = oauthToken.WithExtra(map[string]interface{}{"id_token": "id-token"}) - expressionQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { - t.Helper() + newTestScenario(t, "Datasource with no custom HTTP settings", + options( + withIncomingRequest(func(req *http.Request) { + req.Header.Set("X-Custom", "custom") + req.AddCookie(&http.Cookie{Name: "cookie1"}) + req.AddCookie(&http.Cookie{Name: "cookie2"}) + req.AddCookie(&http.Cookie{Name: "cookie3"}) + req.AddCookie(&http.Cookie{Name: loginCookieName}) + }), + ), + func(t *testing.T, tsCtx *testScenarioContext) { + verify := func(h backend.ForwardHTTPHeaders) { + require.NotNil(t, h) + require.Empty(t, h.GetHTTPHeader(backend.CookiesHeaderName)) + require.Empty(t, h.GetHTTPHeader("Authorization")) - return metricRequestWithQueries(t, fmt.Sprintf(`{ - "refId": "A", - "datasource": { - "uid": "%s", - "type": "%s" + require.NotNil(t, tsCtx.outgoingRequest) + require.NotEmpty(t, tsCtx.outgoingRequest.Header) } - }`, tsCtx.uid, tsCtx.testPluginID), `{ - "refId": "B", - "datasource": { - "type": "__expr__", - "uid": "__expr__", - "name": "Expression" - }, - "type": "math", - "expression": "$A - 50" - }`) - } - newTestScenario(t, "When oauth token not available", func(t *testing.T, tsCtx *testScenarioContext) { - tsCtx.testEnv.OAuthTokenService.Token = nil + tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) { + verify(pReq) + }) - tsCtx.runCheckHealthTest(t) - tsCtx.runCallResourceTest(t) + tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) { + verify(pReq) + require.Equal(t, "custom", pReq.GetHTTPHeader("X-Custom")) + require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom")) + }) - t.Run("regular query", func(t *testing.T) { - tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx)) + verifyQueryData := func(pReq *backend.QueryDataRequest) { + verify(pReq) + } + + t.Run("regular query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData) + + t.Run("expression query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData) + }) + }) }) - t.Run("expression query", func(t *testing.T) { - tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx)) - }) - }) + newTestScenario(t, "Datasource with most HTTP settings set except oauthPassThru and oauth token available", + options( + withIncomingRequest(func(req *http.Request) { + req.AddCookie(&http.Cookie{Name: "cookie1"}) + req.AddCookie(&http.Cookie{Name: "cookie2"}) + req.AddCookie(&http.Cookie{Name: "cookie3"}) + req.AddCookie(&http.Cookie{Name: loginCookieName}) + }), + withOAuthToken(oauthToken), + withDsBasicAuth("basicAuthUser", "basicAuthPassword"), + withDsCustomHeader(map[string]string{"X-CUSTOM-HEADER": "custom-header-value"}), + withDsCookieForwarding([]string{"cookie1", "cookie3", loginCookieName}), + ), + func(t *testing.T, tsCtx *testScenarioContext) { + verify := func(h backend.ForwardHTTPHeaders) { + require.NotNil(t, h) + require.Equal(t, "cookie1=; cookie3=", h.GetHTTPHeader(backend.CookiesHeaderName)) + require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)) + require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName)) - newTestScenario(t, "When oauth token available", func(t *testing.T, tsCtx *testScenarioContext) { - token := &oauth2.Token{ - TokenType: "bearer", - AccessToken: "access-token", - RefreshToken: "refresh-token", - Expiry: time.Now().UTC().Add(24 * time.Hour), - } - token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) - tsCtx.testEnv.OAuthTokenService.Token = token + require.NotNil(t, tsCtx.outgoingRequest) + require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get(backend.CookiesHeaderName)) + require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) - tsCtx.runCheckHealthTest(t) - tsCtx.runCallResourceTest(t) + username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() + require.True(t, ok) + require.Equal(t, "basicAuthUser", username) + require.Equal(t, "basicAuthPassword", pwd) + } - t.Run("regular query", func(t *testing.T) { - tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx)) + tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) { + verify(pReq) + }) + + tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) { + verify(pReq) + }) + + verifyQueryData := func(pReq *backend.QueryDataRequest) { + verify(pReq) + } + + t.Run("regular query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData) + + t.Run("expression query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData) + }) + }) }) - t.Run("expression query", func(t *testing.T) { - tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx)) + newTestScenario(t, "Datasource with oauthPassThru and basic auth configured and oauth token available", + options( + withOAuthToken(oauthToken), + withDsOAuthForwarding(), + withDsBasicAuth("basicAuthUser", "basicAuthPassword"), + ), + func(t *testing.T, tsCtx *testScenarioContext) { + verify := func(h backend.ForwardHTTPHeaders) { + require.NotNil(t, h) + + expectedAuthHeader := fmt.Sprintf("Bearer %s", oauthToken.AccessToken) + expectedTokenHeader := oauthToken.Extra("id_token").(string) + + require.Equal(t, expectedAuthHeader, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName)) + require.Equal(t, expectedTokenHeader, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName)) + + require.NotNil(t, tsCtx.outgoingRequest) + require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityTokenHeaderName)) + require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityIDTokenHeaderName)) + } + + tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) { + verify(pReq) + }) + + tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) { + verify(pReq) + }) + + verifyQueryData := func(pReq *backend.QueryDataRequest) { + verify(pReq) + } + + t.Run("regular query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData) + + t.Run("expression query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData) + }) + }) }) - }) } type testScenarioContext struct { - testPluginID string - uid string - grafanaListeningAddr string - testEnv *server.TestEnv - outgoingServer *httptest.Server - outgoingRequest *http.Request - backendTestPlugin *testPlugin - rt http.RoundTripper + testPluginID string + uid string + grafanaListeningAddr string + testEnv *server.TestEnv + outgoingServer *httptest.Server + outgoingRequest *http.Request + backendTestPlugin *testPlugin + rt http.RoundTripper + modifyIncomingRequest func(req *http.Request) } -func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx *testScenarioContext)) { +type testScenarioInput struct { + ds *datasources.AddDataSourceCommand + token *oauth2.Token + modifyIncomingRequest func(req *http.Request) +} + +type testScenarioOption func(*testScenarioInput) + +func options(opts ...testScenarioOption) []testScenarioOption { + return opts +} + +func withIncomingRequest(cb func(req *http.Request)) testScenarioOption { + return func(in *testScenarioInput) { + in.modifyIncomingRequest = cb + } +} + +func withOAuthToken(token *oauth2.Token) testScenarioOption { + return func(in *testScenarioInput) { + in.token = token + } +} + +func withDsBasicAuth(username, password string) testScenarioOption { + return func(in *testScenarioInput) { + in.ds.BasicAuth = true + in.ds.BasicAuthUser = username + in.ds.SecureJsonData["basicAuthPassword"] = password + } +} + +func withDsCustomHeader(headers map[string]string) testScenarioOption { + return func(in *testScenarioInput) { + index := 1 + for k, v := range headers { + in.ds.JsonData.Set(fmt.Sprintf("httpHeaderName%d", index), k) + in.ds.SecureJsonData[fmt.Sprintf("httpHeaderValue%d", index)] = v + index++ + } + } +} + +func withDsOAuthForwarding() testScenarioOption { + return func(in *testScenarioInput) { + in.ds.JsonData.Set("oauthPassThru", true) + } +} + +func withDsCookieForwarding(names []string) testScenarioOption { + return func(in *testScenarioInput) { + in.ds.JsonData.Set("keepCookies", names) + } +} + +func newTestScenario(t *testing.T, name string, opts []testScenarioOption, callback func(t *testing.T, ctx *testScenarioContext)) { tsCtx := testScenarioContext{ testPluginID: "test-plugin", } @@ -123,6 +257,7 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) tsCtx.grafanaListeningAddr = grafanaListeningAddr + testEnv.SQLStore.Cfg.LoginCookieName = loginCookieName tsCtx.testEnv = testEnv ctx := context.Background() @@ -143,29 +278,30 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx err := testEnv.PluginRegistry.Add(ctx, testPlugin) require.NoError(t, err) - jsonData := simplejson.NewFromAny(map[string]interface{}{ - "httpHeaderName1": "X-CUSTOM-HEADER", - "oauthPassThru": true, - "keepCookies": []string{"cookie1", "cookie3", "grafana_session"}, - }) - secureJSONData := map[string]string{ - "basicAuthPassword": "basicAuthPassword", - "httpHeaderValue1": "custom-header-value", - } + jsonData := simplejson.New() + secureJSONData := map[string]string{} tsCtx.uid = "test-plugin" - err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, &datasources.AddDataSourceCommand{ + cmd := &datasources.AddDataSourceCommand{ OrgId: 1, Access: datasources.DS_ACCESS_PROXY, Name: "TestPlugin", Type: tsCtx.testPluginID, Uid: tsCtx.uid, Url: tsCtx.outgoingServer.URL, - BasicAuth: true, - BasicAuthUser: "basicAuthUser", JsonData: jsonData, SecureJsonData: secureJSONData, - }) + } + + in := &testScenarioInput{ds: cmd} + for _, opt := range opts { + opt(in) + } + + tsCtx.modifyIncomingRequest = in.modifyIncomingRequest + tsCtx.testEnv.OAuthTokenService.Token = in.token + + err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, cmd) require.NoError(t, err) getDataSourceQuery := &datasources.GetDataSourceQuery{ @@ -185,17 +321,42 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx }) } -func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest) { - t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) { - var received *struct { - ctx context.Context - req *backend.QueryDataRequest +func createRegularQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { + t.Helper() + + return metricRequestWithQueries(t, fmt.Sprintf(`{ + "datasource": { + "uid": "%s" } + }`, tsCtx.uid)) +} + +func createExpressionQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { + t.Helper() + + return metricRequestWithQueries(t, fmt.Sprintf(`{ + "refId": "A", + "datasource": { + "uid": "%s", + "type": "%s" + } + }`, tsCtx.uid, tsCtx.testPluginID), `{ + "refId": "B", + "datasource": { + "type": "__expr__", + "uid": "__expr__", + "name": "Expression" + }, + "type": "math", + "expression": "$A - 50" + }`) +} + +func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest, callback func(req *backend.QueryDataRequest)) { + t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) { + var received *backend.QueryDataRequest tsCtx.backendTestPlugin.QueryDataHandler = backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - received = &struct { - ctx context.Context - req *backend.QueryDataRequest - }{ctx, req} + received = req c := http.Client{ Transport: tsCtx.rt, @@ -227,18 +388,11 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR req, err := http.NewRequest(http.MethodPost, u, buf1) req.Header.Set("Content-Type", "application/json") - req.AddCookie(&http.Cookie{ - Name: "cookie1", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie2", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie3", - }) - req.AddCookie(&http.Cookie{ - Name: "grafana_session", - }) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36") + + if tsCtx.modifyIncomingRequest != nil { + tsCtx.modifyIncomingRequest(req) + } require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -253,51 +407,18 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR _, err = io.ReadAll(resp.Body) require.NoError(t, err) - // backend query data request - require.NotNil(t, received) - require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"]) + require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding")) + require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent")) - token := tsCtx.testEnv.OAuthTokenService.Token - - var expectedAuthHeader string - var expectedTokenHeader string - - if token != nil { - expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) - expectedTokenHeader = token.Extra("id_token").(string) - - require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"]) - require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"]) - } - - // outgoing HTTP request - require.NotNil(t, tsCtx.outgoingRequest) - require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) - require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) - - if token == nil { - username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() - require.True(t, ok) - require.Equal(t, "basicAuthUser", username) - require.Equal(t, "basicAuthPassword", pwd) - } else { - require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) - require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) - } + callback(received) }) } -func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) { +func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T, callback func(req *backend.CheckHealthRequest)) { t.Run("When calling /api/datasources/uid/:uid/health should set expected headers on outgoing CheckHealth and HTTP request", func(t *testing.T) { - var received *struct { - ctx context.Context - req *backend.CheckHealthRequest - } + var received *backend.CheckHealthRequest tsCtx.backendTestPlugin.CheckHealthHandler = backend.CheckHealthHandlerFunc(func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { - received = &struct { - ctx context.Context - req *backend.CheckHealthRequest - }{ctx, req} + received = req c := http.Client{ Transport: tsCtx.rt, @@ -328,18 +449,11 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) { req, err := http.NewRequest(http.MethodGet, u, nil) req.Header.Set("Content-Type", "application/json") - req.AddCookie(&http.Cookie{ - Name: "cookie1", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie2", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie3", - }) - req.AddCookie(&http.Cookie{ - Name: "grafana_session", - }) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36") + + if tsCtx.modifyIncomingRequest != nil { + tsCtx.modifyIncomingRequest(req) + } require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -354,51 +468,18 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) { _, err = io.ReadAll(resp.Body) require.NoError(t, err) - // backend query data request - require.NotNil(t, received) - require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"]) + require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding")) + require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent")) - token := tsCtx.testEnv.OAuthTokenService.Token - - var expectedAuthHeader string - var expectedTokenHeader string - - if token != nil { - expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) - expectedTokenHeader = token.Extra("id_token").(string) - - require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"]) - require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"]) - } - - // outgoing HTTP request - require.NotNil(t, tsCtx.outgoingRequest) - require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) - require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) - - if token == nil { - username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() - require.True(t, ok) - require.Equal(t, "basicAuthUser", username) - require.Equal(t, "basicAuthPassword", pwd) - } else { - require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) - require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) - } + callback(received) }) } -func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) { +func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T, callback func(req *backend.CallResourceRequest)) { t.Run("When calling /api/datasources/uid/:uid/resources should set expected headers on outgoing CallResource and HTTP request", func(t *testing.T) { - var received *struct { - ctx context.Context - req *backend.CallResourceRequest - } + var received *backend.CallResourceRequest tsCtx.backendTestPlugin.CallResourceHandler = backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { - received = &struct { - ctx context.Context - req *backend.CallResourceRequest - }{ctx, req} + received = req c := http.Client{ Transport: tsCtx.rt, @@ -450,19 +531,11 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) { req.Header.Set("Connection", "X-Some-Conn-Header") req.Header.Set("X-Some-Conn-Header", "should be deleted") req.Header.Set("Proxy-Connection", "should be deleted") - req.Header.Set("X-Custom", "custom") - req.AddCookie(&http.Cookie{ - Name: "cookie1", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie2", - }) - req.AddCookie(&http.Cookie{ - Name: "cookie3", - }) - req.AddCookie(&http.Cookie{ - Name: "grafana_session", - }) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36") + + if tsCtx.modifyIncomingRequest != nil { + tsCtx.modifyIncomingRequest(req) + } require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -484,45 +557,18 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) { require.Empty(t, resp.Header.Get("Set-Cookie")) require.Equal(t, "should not be deleted", resp.Header.Get("X-Custom")) - // backend query data request require.NotNil(t, received) - require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"][0]) - require.Empty(t, received.req.Headers["Connection"]) - require.Empty(t, received.req.Headers["X-Some-Conn-Header"]) - require.Empty(t, received.req.Headers["Proxy-Connection"]) - require.Equal(t, "custom", received.req.Headers["X-Custom"][0]) + require.Empty(t, received.Headers["Connection"]) + require.Empty(t, received.Headers["X-Some-Conn-Header"]) + require.Empty(t, received.Headers["Proxy-Connection"]) - token := tsCtx.testEnv.OAuthTokenService.Token - - var expectedAuthHeader string - var expectedTokenHeader string - - if token != nil { - expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) - expectedTokenHeader = token.Extra("id_token").(string) - - require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"][0]) - require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"][0]) - } - - // outgoing HTTP request - require.NotNil(t, tsCtx.outgoingRequest) - require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) require.Empty(t, tsCtx.outgoingRequest.Header.Get("Connection")) require.Empty(t, tsCtx.outgoingRequest.Header.Get("X-Some-Conn-Header")) require.Empty(t, tsCtx.outgoingRequest.Header.Get("Proxy-Connection")) - require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom")) - require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) + require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding")) + require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent")) - if token == nil { - username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() - require.True(t, ok) - require.Equal(t, "basicAuthUser", username) - require.Equal(t, "basicAuthPassword", pwd) - } else { - require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) - require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) - } + callback(received) }) } diff --git a/pkg/tsdb/cloudwatch/cloudwatch.go b/pkg/tsdb/cloudwatch/cloudwatch.go index d82fcb902b1..c87b77717b7 100644 --- a/pkg/tsdb/cloudwatch/cloudwatch.go +++ b/pkg/tsdb/cloudwatch/cloudwatch.go @@ -23,6 +23,7 @@ import ( "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/services/featuremgmt" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/cloudwatch/clients" "github.com/grafana/grafana/pkg/tsdb/cloudwatch/models" @@ -162,7 +163,7 @@ func (e *cloudWatchExecutor) QueryData(ctx context.Context, req *backend.QueryDa if err != nil { return nil, err } - _, fromAlert := req.Headers["FromAlert"] + _, fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName] isLogAlertQuery := fromAlert && model.QueryMode == logsQueryMode if isLogAlertQuery { diff --git a/pkg/tsdb/cloudwatch/cloudwatch_test.go b/pkg/tsdb/cloudwatch/cloudwatch_test.go index c12ca03d5f9..caf6e1274d8 100644 --- a/pkg/tsdb/cloudwatch/cloudwatch_test.go +++ b/pkg/tsdb/cloudwatch/cloudwatch_test.go @@ -20,6 +20,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/services/featuremgmt" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/tsdb/cloudwatch/mocks" "github.com/grafana/grafana/pkg/tsdb/cloudwatch/models" "github.com/grafana/grafana/pkg/tsdb/cloudwatch/utils" @@ -212,7 +213,7 @@ func Test_executeLogAlertQuery(t *testing.T) { executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures()) _, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{ - Headers: map[string]string{"FromAlert": "some value"}, + Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"}, PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}}, Queries: []backend.DataQuery{ { @@ -238,7 +239,7 @@ func Test_executeLogAlertQuery(t *testing.T) { executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures()) _, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{ - Headers: map[string]string{"FromAlert": "some value"}, + Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"}, PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}}, Queries: []backend.DataQuery{ { diff --git a/pkg/tsdb/loki/api.go b/pkg/tsdb/loki/api.go index 15743e5efeb..681c4f1b1e4 100644 --- a/pkg/tsdb/loki/api.go +++ b/pkg/tsdb/loki/api.go @@ -18,10 +18,9 @@ import ( ) type LokiAPI struct { - client *http.Client - url string - log log.Logger - headers map[string]string + client *http.Client + url string + log log.Logger } type RawLokiResponse struct { @@ -29,17 +28,11 @@ type RawLokiResponse struct { Encoding string } -func newLokiAPI(client *http.Client, url string, log log.Logger, headers map[string]string) *LokiAPI { - return &LokiAPI{client: client, url: url, log: log, headers: headers} +func newLokiAPI(client *http.Client, url string, log log.Logger) *LokiAPI { + return &LokiAPI{client: client, url: url, log: log} } -func addHeaders(req *http.Request, headers map[string]string) { - for name, value := range headers { - req.Header.Set(name, value) - } -} - -func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, headers map[string]string) (*http.Request, error) { +func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*http.Request, error) { qs := url.Values{} qs.Set("query", query.Expr) @@ -92,8 +85,6 @@ func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, hea return nil, err } - addHeaders(req, headers) - if query.VolumeQuery { req.Header.Set("X-Query-Tags", "Source=logvolhist") } @@ -145,7 +136,7 @@ func makeLokiError(body io.ReadCloser) error { } func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames, error) { - req, err := makeDataRequest(ctx, api.url, query, api.headers) + req, err := makeDataRequest(ctx, api.url, query) if err != nil { return nil, err } @@ -183,7 +174,7 @@ func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames return res.Frames, nil } -func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string, headers map[string]string) (*http.Request, error) { +func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string) (*http.Request, error) { lokiUrl, err := url.Parse(lokiDsUrl) if err != nil { return nil, err @@ -204,13 +195,11 @@ func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string, return nil, err } - addHeaders(req, headers) - return req, nil } func (api *LokiAPI) RawQuery(ctx context.Context, resourcePath string) (RawLokiResponse, error) { - req, err := makeRawRequest(ctx, api.url, resourcePath, api.headers) + req, err := makeRawRequest(ctx, api.url, resourcePath) if err != nil { return RawLokiResponse{}, err } diff --git a/pkg/tsdb/loki/api_mock.go b/pkg/tsdb/loki/api_mock.go index 425c0a885a0..0548f798c61 100644 --- a/pkg/tsdb/loki/api_mock.go +++ b/pkg/tsdb/loki/api_mock.go @@ -64,7 +64,7 @@ func makeMockedAPIWithUrl(url string, statusCode int, contentType string, respon Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback}, } - return newLokiAPI(&client, url, log.New("test"), nil) + return newLokiAPI(&client, url, log.New("test")) } func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType string, responseBytes []byte, requestCallback mockRequestCallback) *LokiAPI { @@ -72,5 +72,5 @@ func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType stri Transport: &mockedCompressedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback}, } - return newLokiAPI(&client, url, log.New("test"), nil) + return newLokiAPI(&client, url, log.New("test")) } diff --git a/pkg/tsdb/loki/auth_test.go b/pkg/tsdb/loki/auth_test.go deleted file mode 100644 index caaee66b6ef..00000000000 --- a/pkg/tsdb/loki/auth_test.go +++ /dev/null @@ -1,201 +0,0 @@ -package loki - -import ( - "bytes" - "context" - "io" - "net/http" - "testing" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/stretchr/testify/require" - - "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/infra/tracing" -) - -type mockedRoundTripperForOauth struct { - requestCallback func(req *http.Request) - body []byte -} - -func (mockedRT *mockedRoundTripperForOauth) RoundTrip(req *http.Request) (*http.Response, error) { - mockedRT.requestCallback(req) - return &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{}, - Body: io.NopCloser(bytes.NewReader(mockedRT.body)), - }, nil -} - -type mockedCallResourceResponseSenderForOauth struct { - Response *backend.CallResourceResponse -} - -func (s *mockedCallResourceResponseSenderForOauth) Send(resp *backend.CallResourceResponse) error { - s.Response = resp - return nil -} - -func makeMockedDsInfoForOauth(body []byte, requestCallback func(req *http.Request)) datasourceInfo { - client := http.Client{ - Transport: &mockedRoundTripperForOauth{requestCallback: requestCallback, body: body}, - } - - return datasourceInfo{ - HTTPClient: &client, - } -} - -func TestOauthForwardIdentity(t *testing.T) { - tt := []struct { - name string - auth bool - cookie bool - }{ - {name: "when auth headers exist => add auth headers", auth: true, cookie: false}, - {name: "when cookie header exists => add cookie header", auth: false, cookie: true}, - {name: "when cookie&auth headers exist => add cookie&auth headers", auth: true, cookie: true}, - {name: "when no header exists => do not add headers", auth: false, cookie: false}, - } - - authName := "Authorization" - authValue := "auth" - cookieName := "Cookie" - cookieValue := "a=1" - idTokenName := "X-ID-Token" - idTokenValue := "idtoken" - - for _, test := range tt { - t.Run("QueryData: "+test.name, func(t *testing.T) { - response := []byte(` - { - "status": "success", - "data": { - "resultType": "streams", - "result": [ - { - "stream": {}, - "values": [ - ["1", "line1"] - ] - } - ] - } - } - `) - - clientUsed := false - dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) { - clientUsed = true - // we need to check for "header does not exist", - // and the only way i can find is to get the values - // as an array - authValues := req.Header.Values(authName) - cookieValues := req.Header.Values(cookieName) - idTokenValues := req.Header.Values(idTokenName) - if test.auth { - require.Equal(t, []string{authValue}, authValues) - require.Equal(t, []string{idTokenValue}, idTokenValues) - } else { - require.Len(t, authValues, 0) - require.Len(t, idTokenValues, 0) - } - if test.cookie { - require.Equal(t, []string{cookieValue}, cookieValues) - } else { - require.Len(t, cookieValues, 0) - } - }) - - req := backend.QueryDataRequest{ - Headers: map[string]string{}, - Queries: []backend.DataQuery{ - { - RefID: "A", - JSON: []byte("{}"), - }, - }, - } - - if test.auth { - req.Headers[authName] = authValue - req.Headers[idTokenName] = idTokenValue - } - - if test.cookie { - req.Headers[cookieName] = cookieValue - } - - tracer := tracing.InitializeTracerForTest() - - data, err := queryData(context.Background(), &req, &dsInfo, tracer) - // we do a basic check that the result is OK - require.NoError(t, err) - require.Len(t, data.Responses, 1) - res := data.Responses["A"] - require.NoError(t, res.Error) - require.Len(t, res.Frames, 1) - require.Equal(t, "line1", res.Frames[0].Fields[2].At(0)) - - // we need to be sure the client-callback was triggered - require.True(t, clientUsed) - }) - } - - for _, test := range tt { - t.Run("CallResource: "+test.name, func(t *testing.T) { - response := []byte("mocked resource response") - - clientUsed := false - dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) { - clientUsed = true - authValues := req.Header.Values(authName) - cookieValues := req.Header.Values(cookieName) - idTokenValues := req.Header.Values(idTokenName) - // we need to check for "header does not exist", - // and the only way i can find is to get the values - // as an array - if test.auth { - require.Equal(t, []string{authValue}, authValues) - require.Equal(t, []string{idTokenValue}, idTokenValues) - } else { - require.Len(t, authValues, 0) - require.Len(t, idTokenValues, 0) - } - if test.cookie { - require.Equal(t, []string{cookieValue}, cookieValues) - } else { - require.Len(t, cookieValues, 0) - } - }) - - req := backend.CallResourceRequest{ - Headers: map[string][]string{}, - Method: "GET", - URL: "labels?", - } - - if test.auth { - req.Headers[authName] = []string{authValue} - req.Headers[idTokenName] = []string{idTokenValue} - } - if test.cookie { - req.Headers[cookieName] = []string{cookieValue} - } - - sender := &mockedCallResourceResponseSenderForOauth{} - - err := callResource(context.Background(), &req, sender, &dsInfo, log.New("testlog")) - // we do a basic check that the result is OK - require.NoError(t, err) - sent := sender.Response - require.NotNil(t, sent) - require.Equal(t, http.StatusOK, sent.Status) - require.Equal(t, response, sent.Body) - - // we need to be sure the client-callback was triggered - require.True(t, clientUsed) - }) - } -} diff --git a/pkg/tsdb/loki/loki.go b/pkg/tsdb/loki/loki.go index 6742398a226..6925a91c25b 100644 --- a/pkg/tsdb/loki/loki.go +++ b/pkg/tsdb/loki/loki.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "net/textproto" "regexp" "strings" "sync" @@ -96,22 +95,6 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst } } -// in the CallResource API, request-headers are in a map where the value is an array-of-strings, -// so we need a helper function that can extract a single string-value from an array-of-strings. -// i only deal with two cases: -// - zero-length array -// - first-item of the array -// i do not handle the case where there are multiple items in the array, i do not know -// if that can even happen ever, for the headers that we are interested in. -func arrayHeaderFirstValue(values []string) string { - if len(values) == 0 { - return "" - } - - // NOTE: we assume there never is a second item in the http-header-values-array - return values[0] -} - func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { dsInfo, err := s.getDSInfo(req.PluginContext) if err != nil { @@ -120,30 +103,6 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq return callResource(ctx, req, sender, dsInfo, logger.FromContext(ctx)) } -func getHeadersForCallResource(headers map[string][]string) map[string]string { - data := make(map[string]string) - - for k, values := range headers { - k = textproto.CanonicalMIMEHeaderKey(k) - firstValue := arrayHeaderFirstValue(values) - - if firstValue == "" { - continue - } - switch k { - case "Authorization": - data["Authorization"] = firstValue - case "X-Id-Token": - data["X-ID-Token"] = firstValue - case "Cookie": - data["Cookie"] = firstValue - case "Accept-Encoding": - data["Accept-Encoding"] = firstValue - } - } - return data -} - func callResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, dsInfo *datasourceInfo, plog log.Logger) error { url := req.URL @@ -158,7 +117,7 @@ func callResource(ctx context.Context, req *backend.CallResourceRequest, sender } lokiURL := fmt.Sprintf("/loki/api/v1/%s", url) - api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getHeadersForCallResource(req.Headers)) + api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog) encodedBytes, err := api.RawQuery(ctx, lokiURL) if err != nil { @@ -191,7 +150,7 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) func queryData(ctx context.Context, req *backend.QueryDataRequest, dsInfo *datasourceInfo, tracer tracing.Tracer) (*backend.QueryDataResponse, error) { result := backend.NewQueryDataResponse() - api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx), req.Headers) + api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx)) queries, err := parseQuery(req) if err != nil { diff --git a/pkg/tsdb/loki/loki_test.go b/pkg/tsdb/loki/loki_test.go deleted file mode 100644 index a7aa8a29524..00000000000 --- a/pkg/tsdb/loki/loki_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package loki - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetHeadersForCallResource(t *testing.T) { - const idTokn1 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - const idTokn2 = "eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU" - const authTokn1 = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - const authTokn2 = "Bearer eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU" - - testCases := map[string]struct { - headers map[string][]string - expectedHeaders map[string]string - }{ - "Headers with empty value": { - headers: map[string][]string{ - "X-Grafana-Org-Id": {"1"}, - "Cookie": {""}, - "X-Id-Token": {""}, - "Accept-Encoding": {""}, - "Authorization": {""}, - }, - expectedHeaders: map[string]string{}, - }, - "Headers with multiple values": { - headers: map[string][]string{ - "Authorization": {authTokn1, authTokn2}, - "Cookie": {"a=1"}, - "X-Grafana-Org-Id": {"1"}, - "Accept-Encoding": {"gzip", "compress"}, - "X-Id-Token": {idTokn1, idTokn2}, - }, - expectedHeaders: map[string]string{ - "Authorization": authTokn1, - "Cookie": "a=1", - "Accept-Encoding": "gzip", - "X-ID-Token": idTokn1, - }, - }, - "Headers with single value": { - headers: map[string][]string{ - "Authorization": {authTokn1}, - "X-Grafana-Org-Id": {"1"}, - "Cookie": {"a=1"}, - "Accept-Encoding": {"gzip"}, - "X-Id-Token": {idTokn1}, - }, - expectedHeaders: map[string]string{ - "Authorization": authTokn1, - "Cookie": "a=1", - "Accept-Encoding": "gzip", - "X-ID-Token": idTokn1, - }, - }, - "Non Canonical 'X-Id-Token' header key": { - headers: map[string][]string{ - "X-ID-TOKEN": {idTokn1}, - }, - expectedHeaders: map[string]string{ - "X-ID-Token": idTokn1, - }, - }, - } - for name, test := range testCases { - t.Run(name, func(t *testing.T) { - headers := getHeadersForCallResource(test.headers) - assert.Equal(t, test.expectedHeaders, headers) - }) - } -} diff --git a/pkg/tsdb/prometheus/client/client.go b/pkg/tsdb/prometheus/client/client.go index c3f0c4d0719..cc9a40ef21e 100644 --- a/pkg/tsdb/prometheus/client/client.go +++ b/pkg/tsdb/prometheus/client/client.go @@ -32,7 +32,7 @@ func NewClient(d doer, method, baseUrl string) *Client { return &Client{doer: d, method: method, baseUrl: baseUrl} } -func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) { +func (c *Client) QueryRange(ctx context.Context, q *models.Query) (*http.Response, error) { tr := q.TimeRange() qv := map[string]string{ "query": q.Expr, @@ -41,7 +41,7 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H "step": strconv.FormatFloat(tr.Step.Seconds(), 'f', -1, 64), } - req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv, headers) + req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv) if err != nil { return nil, err } @@ -49,14 +49,14 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H return c.doer.Do(req) } -func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) { +func (c *Client) QueryInstant(ctx context.Context, q *models.Query) (*http.Response, error) { qv := map[string]string{"query": q.Expr} tr := q.TimeRange() if !tr.End.IsZero() { qv["time"] = formatTime(tr.End) } - req, err := c.createQueryRequest(ctx, "api/v1/query", qv, headers) + req, err := c.createQueryRequest(ctx, "api/v1/query", qv) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http return c.doer.Do(req) } -func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) { +func (c *Client) QueryExemplars(ctx context.Context, q *models.Query) (*http.Response, error) { tr := q.TimeRange() qv := map[string]string{ "query": q.Expr, @@ -72,7 +72,7 @@ func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers ht "end": formatTime(tr.End), } - req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv, headers) + req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq // We use method from the request, as for resources front end may do a fallback to GET if POST does not work // nad we want to respect that. - httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body), req.Headers) + httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body)) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq return c.doer.Do(httpRequest) } -func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string, headers http.Header) (*http.Request, error) { +func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string) (*http.Request, error) { if strings.ToUpper(c.method) == http.MethodPost { u, err := c.createUrl(endpoint, nil) if err != nil { @@ -115,7 +115,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map v.Set(key, val) } - return createRequest(ctx, c.method, u, strings.NewReader(v.Encode()), headers) + return createRequest(ctx, c.method, u, strings.NewReader(v.Encode())) } u, err := c.createUrl(endpoint, qv) @@ -123,7 +123,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map return nil, err } - return createRequest(ctx, c.method, u, http.NoBody, headers) + return createRequest(ctx, c.method, u, http.NoBody) } func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, error) { @@ -148,15 +148,12 @@ func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, err return finalUrl, nil } -func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader, header http.Header) (*http.Request, error) { +func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader) (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, method, u.String(), bodyReader) if err != nil { return nil, err } - // request.Header is created empty from NewRequestWithContext so we can just replace it - if header != nil { - request.Header = header - } + if strings.ToUpper(method) == http.MethodPost { // This may not be true but right now we don't have more information here and seems like we send just this type // of encoding right now if it is a POST diff --git a/pkg/tsdb/prometheus/client/client_test.go b/pkg/tsdb/prometheus/client/client_test.go index 9e18008b422..e9928909e93 100644 --- a/pkg/tsdb/prometheus/client/client_test.go +++ b/pkg/tsdb/prometheus/client/client_test.go @@ -37,7 +37,6 @@ func TestClient(t *testing.T) { Path: "/api/v1/series", Method: http.MethodPost, URL: "/api/v1/series", - Headers: nil, Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"), } res, err := client.QueryResource(context.Background(), req) @@ -63,7 +62,6 @@ func TestClient(t *testing.T) { Path: "/api/v1/series", Method: http.MethodGet, URL: "api/v1/series?match%5B%5D=ALERTS&start=1655272558&end=1655294158", - Headers: nil, } res, err := client.QueryResource(context.Background(), req) defer func() { @@ -95,7 +93,7 @@ func TestClient(t *testing.T) { RangeQuery: true, Step: 1 * time.Second, } - res, err := client.QueryRange(context.Background(), req, nil) + res, err := client.QueryRange(context.Background(), req) defer func() { if res != nil && res.Body != nil { if err := res.Body.Close(); err != nil { @@ -122,7 +120,7 @@ func TestClient(t *testing.T) { RangeQuery: true, Step: 1 * time.Second, } - res, err := client.QueryRange(context.Background(), req, nil) + res, err := client.QueryRange(context.Background(), req) defer func() { if res != nil && res.Body != nil { if err := res.Body.Close(); err != nil { diff --git a/pkg/tsdb/prometheus/prometheus_test.go b/pkg/tsdb/prometheus/prometheus_test.go index 4cca2bacb95..5f9969ce237 100644 --- a/pkg/tsdb/prometheus/prometheus_test.go +++ b/pkg/tsdb/prometheus/prometheus_test.go @@ -85,11 +85,7 @@ func TestService(t *testing.T) { Path: "/api/v1/series", Method: http.MethodPost, URL: "/api/v1/series", - // This header should be passed on to the resource request - Headers: map[string][]string{ - "foo": {"bar"}, - }, - Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"), + Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"), } sender := &fakeSender{} @@ -100,7 +96,6 @@ func TestService(t *testing.T) { http.Header{ "Content-Type": {"application/x-www-form-urlencoded"}, "Idempotency-Key": []string(nil), - "foo": {"bar"}, }, httpProvider.Roundtripper.Req.Header) require.Equal(t, http.MethodPost, httpProvider.Roundtripper.Req.Method) diff --git a/pkg/tsdb/prometheus/querydata/request.go b/pkg/tsdb/prometheus/querydata/request.go index 4fc958e536b..78dccd59658 100644 --- a/pkg/tsdb/prometheus/querydata/request.go +++ b/pkg/tsdb/prometheus/querydata/request.go @@ -14,6 +14,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/services/featuremgmt" + ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/tsdb/intervalv2" "github.com/grafana/grafana/pkg/tsdb/prometheus/client" "github.com/grafana/grafana/pkg/tsdb/prometheus/models" @@ -79,7 +80,7 @@ func New( } func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - fromAlert := req.Headers["FromAlert"] == "true" + fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName] == "true" result := backend.QueryDataResponse{ Responses: backend.Responses{}, } @@ -89,7 +90,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest) if err != nil { return &result, err } - r, err := s.fetch(ctx, s.client, query, req.Headers) + r, err := s.fetch(ctx, s.client, query) if err != nil { return &result, err } @@ -103,7 +104,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest) return &result, nil } -func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) { +func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query) (*backend.DataResponse, error) { traceCtx, end := s.trace(ctx, q) defer end() @@ -116,7 +117,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models. } if q.InstantQuery { - res, err := s.instantQuery(traceCtx, client, q, headers) + res, err := s.instantQuery(traceCtx, client, q) if err != nil { return nil, err } @@ -125,7 +126,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models. } if q.RangeQuery { - res, err := s.rangeQuery(traceCtx, client, q, headers) + res, err := s.rangeQuery(traceCtx, client, q) if err != nil { return nil, err } @@ -140,7 +141,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models. } if q.ExemplarQuery { - res, err := s.exemplarQuery(traceCtx, client, q, headers) + res, err := s.exemplarQuery(traceCtx, client, q) if err != nil { // If exemplar query returns error, we want to only log it and // continue with other results processing @@ -154,24 +155,24 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models. return response, nil } -func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) { - res, err := c.QueryRange(ctx, q, sdkHeaderToHttpHeader(headers)) +func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) { + res, err := c.QueryRange(ctx, q) if err != nil { return nil, err } return s.parseResponse(ctx, q, res) } -func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) { - res, err := c.QueryInstant(ctx, q, sdkHeaderToHttpHeader(headers)) +func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) { + res, err := c.QueryInstant(ctx, q) if err != nil { return nil, err } return s.parseResponse(ctx, q, res) } -func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) { - res, err := c.QueryExemplars(ctx, q, sdkHeaderToHttpHeader(headers)) +func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) { + res, err := c.QueryExemplars(ctx, q) if err != nil { return nil, err } @@ -185,11 +186,3 @@ func (s *QueryData) trace(ctx context.Context, q *models.Query) (context.Context {Key: "stop_unixnano", Value: q.End, Kv: attribute.Key("stop_unixnano").Int64(q.End.UnixNano())}, }) } - -func sdkHeaderToHttpHeader(headers map[string]string) http.Header { - httpHeader := make(http.Header, len(headers)) - for key, val := range headers { - httpHeader.Set(key, val) - } - return httpHeader -} diff --git a/pkg/tsdb/prometheus/querydata/request_test.go b/pkg/tsdb/prometheus/querydata/request_test.go index 5a5239f5bc4..8005790781a 100644 --- a/pkg/tsdb/prometheus/querydata/request_test.go +++ b/pkg/tsdb/prometheus/querydata/request_test.go @@ -16,7 +16,6 @@ import ( "github.com/grafana/grafana/pkg/tsdb/prometheus/client" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" p "github.com/prometheus/common/model" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/httpclient" @@ -343,27 +342,6 @@ func TestPrometheus_parseTimeSeriesResponse(t *testing.T) { }) } -func TestPrometheusCanonicalHeaders(t *testing.T) { - // Ensure headers are always canonicalized for all outgoing requests - b, err := json.Marshal(models.QueryModel{}) - require.NoError(t, err) - query := backend.DataQuery{JSON: b} - tctx, err := setup(true) - require.NoError(t, err) - const idToken = "abc" - _, err = executeWithHeaders(tctx, query, queryResult{}, map[string]string{ - "X-Id-Token": idToken, - "X-ID-Token": idToken, - "X-Other": "thing", - }) - require.NoError(t, err) - assert.NotEmpty(t, tctx.httpProvider.req.Header) - // Check the request that hit the fake prometheus server to ensure headers are valid - assert.Equal(t, []string{idToken}, tctx.httpProvider.req.Header["X-Id-Token"]) - assert.Empty(t, tctx.httpProvider.req.Header["X-ID-Token"]) //nolint:staticcheck - assert.Equal(t, []string{"thing"}, tctx.httpProvider.req.Header["X-Other"]) -} - type queryResult struct { Type p.ValueType `json:"resultType"` Result interface{} `json:"result"`