mirror of
https://github.com/grafana/grafana.git
synced 2024-11-24 09:50:29 -06:00
Plugins: Automatically forward plugin request HTTP headers in outgoing HTTP requests (#60417)
Automatically forward core plugin request HTTP headers in outgoing HTTP requests. Core datasource plugin authors don't have to specifically handle forwarding of HTTP headers, e.g. do not have to "hardcode" the header-names in the datasource plugin, if not having custom needs. Fixes #57065
This commit is contained in:
parent
aaab477594
commit
c35c689a96
@ -1,27 +0,0 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
const DeleteHeadersMiddlewareName = "delete-headers"
|
||||
|
||||
// DeleteHeadersMiddleware middleware that delete headers on the outgoing
|
||||
// request if header names provided.
|
||||
func DeleteHeadersMiddleware(headerNames ...string) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(DeleteHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if len(headerNames) == 0 {
|
||||
return next
|
||||
}
|
||||
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
for _, k := range headerNames {
|
||||
req.Header.Del(k)
|
||||
}
|
||||
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
})
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeleteHeadersMiddleware(t *testing.T) {
|
||||
t.Run("Without headerNames should return next http.RoundTripper", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
var headerNames []string
|
||||
mw := DeleteHeadersMiddleware(headerNames...)
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
})
|
||||
|
||||
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("final")
|
||||
headerNames := []string{"X-Header-B", "X-Header-C"}
|
||||
mw := DeleteHeadersMiddleware(headerNames...)
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("X-Header-A", "a")
|
||||
req.Header.Set("X-Header-B", "b")
|
||||
req.Header.Set("X-Header-C", "c")
|
||||
req.Header.Set("X-Header-D", "d")
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
|
||||
|
||||
require.Equal(t, "a", req.Header.Get("X-Header-A"))
|
||||
require.Empty(t, req.Header.Get("X-Header-B"))
|
||||
require.Empty(t, req.Header.Get("X-Header-C"))
|
||||
require.Equal(t, "d", req.Header.Get("X-Header-D"))
|
||||
})
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
const SetHeadersMiddlewareName = "set-headers"
|
||||
|
||||
// SetHeadersMiddleware middleware that sets headers on the outgoing
|
||||
// request if headers provided.
|
||||
// If the request already contains any of the headers provided, they
|
||||
// will be overwritten.
|
||||
func SetHeadersMiddleware(headers http.Header) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(SetHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if len(headers) == 0 {
|
||||
return next
|
||||
}
|
||||
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range headers {
|
||||
canonicalKey := textproto.CanonicalMIMEHeaderKey(k)
|
||||
req.Header[canonicalKey] = v
|
||||
}
|
||||
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
})
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetHeadersMiddleware(t *testing.T) {
|
||||
t.Run("Without headers set should return next http.RoundTripper", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
var headers http.Header
|
||||
mw := SetHeadersMiddleware(headers)
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
})
|
||||
|
||||
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("final")
|
||||
headers := http.Header{
|
||||
"X-Header-A": []string{"value a"},
|
||||
"X-Header-B": []string{"value b"},
|
||||
"X-HEader-C": []string{"value c"},
|
||||
}
|
||||
mw := SetHeadersMiddleware(headers)
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("X-Header-B", "d")
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
|
||||
|
||||
require.Equal(t, "value a", req.Header.Get("X-Header-A"))
|
||||
require.Equal(t, "value b", req.Header.Get("X-Header-B"))
|
||||
require.Equal(t, "value c", req.Header.Get("X-Header-C"))
|
||||
})
|
||||
}
|
@ -255,6 +255,7 @@ var hopHeaders = []string{
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
"User-Agent",
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders removes hop-by-hop headers. Especially
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/services/alerting"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -276,8 +277,8 @@ func (c *QueryCondition) getRequestForAlertRule(datasource *datasources.DataSour
|
||||
},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"FromAlert": "true",
|
||||
"X-Cache-Skip": "true",
|
||||
ngalertmodels.FromAlertHeaderName: "true",
|
||||
ngalertmodels.CacheSkipHeaderName: "true",
|
||||
},
|
||||
Debug: debug,
|
||||
}
|
||||
|
@ -223,9 +223,9 @@ func buildDatasourceHeaders(ctx EvaluationContext) map[string]string {
|
||||
// Note: The spelling of this headers is intentionally degenerate from the others for compatibility reasons.
|
||||
// When sent over a network, the key of this header is canonicalized to "Fromalert".
|
||||
// However, some datasources still compare against the string "FromAlert".
|
||||
"FromAlert": "true",
|
||||
models.FromAlertHeaderName: "true",
|
||||
|
||||
"X-Cache-Skip": "true",
|
||||
models.CacheSkipHeaderName: "true",
|
||||
}
|
||||
|
||||
key, ok := models.RuleKeyFromContext(ctx.Ctx)
|
||||
|
21
pkg/services/ngalert/models/constants.go
Normal file
21
pkg/services/ngalert/models/constants.go
Normal file
@ -0,0 +1,21 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
// FromAlertHeaderName name of header added to datasource query requests
|
||||
// to denote request is originating from Grafana Alerting.
|
||||
//
|
||||
// Data sources might check this in query method as sometimes alerting
|
||||
// needs special considerations.
|
||||
// Several existing systems also compare against the value of this header.
|
||||
// Altering this constitutes a breaking change.
|
||||
//
|
||||
// Note: The spelling of this headers is intentionally degenerate from the
|
||||
// others for compatibility reasons. When sent over a network, the key of
|
||||
// this header is canonicalized to "Fromalert".
|
||||
// However, some datasources still compare against the string "FromAlert".
|
||||
FromAlertHeaderName = "FromAlert"
|
||||
|
||||
// CacheSkipHeaderName name of header added to datasource query requests
|
||||
// to denote request should not be cached.
|
||||
CacheSkipHeaderName = "X-Cache-Skip"
|
||||
)
|
@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
)
|
||||
@ -25,19 +23,19 @@ type ClearAuthHeadersMiddleware struct {
|
||||
next plugins.Client
|
||||
}
|
||||
|
||||
func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, pCtx backend.PluginContext, req interface{}) context.Context {
|
||||
func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, h backend.ForwardHTTPHeaders) {
|
||||
reqCtx := contexthandler.FromContext(ctx)
|
||||
// if no HTTP request context skip middleware
|
||||
if req == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
|
||||
return ctx
|
||||
if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
|
||||
return
|
||||
}
|
||||
|
||||
list := contexthandler.AuthHTTPHeaderListFromContext(ctx)
|
||||
if list != nil {
|
||||
ctx = sdkhttpclient.WithContextualMiddleware(ctx, httpclientprovider.DeleteHeadersMiddleware(list.Items...))
|
||||
for _, k := range list.Items {
|
||||
h.DeleteHTTPHeader(k)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
@ -45,7 +43,7 @@ func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.clearHeaders(ctx, req.PluginContext, req)
|
||||
m.clearHeaders(ctx, req)
|
||||
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
@ -55,7 +53,7 @@ func (m *ClearAuthHeadersMiddleware) CallResource(ctx context.Context, req *back
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
ctx = m.clearHeaders(ctx, req.PluginContext, req)
|
||||
m.clearHeaders(ctx, req)
|
||||
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
@ -65,7 +63,7 @@ func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backe
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.clearHeaders(ctx, req.PluginContext, req)
|
||||
m.clearHeaders(ctx, req)
|
||||
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
@ -1,14 +1,10 @@
|
||||
package clientmiddleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
@ -42,9 +38,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
|
||||
@ -55,9 +48,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
|
||||
@ -68,9 +58,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
})
|
||||
|
||||
@ -92,9 +79,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
|
||||
@ -105,9 +89,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
|
||||
@ -118,9 +99,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -155,18 +133,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
})
|
||||
|
||||
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
|
||||
@ -177,18 +144,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader])
|
||||
})
|
||||
|
||||
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
|
||||
@ -199,18 +155,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
})
|
||||
})
|
||||
|
||||
@ -220,12 +165,12 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
|
||||
)
|
||||
|
||||
const customHeader = "X-Custom"
|
||||
const customHeader = "x-Custom"
|
||||
req.Header.Set(customHeader, "val")
|
||||
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
const otherHeader = "X-Other"
|
||||
const otherHeader = "x-Other"
|
||||
req.Header.Set(otherHeader, "test")
|
||||
|
||||
pluginCtx := backend.PluginContext{
|
||||
@ -240,18 +185,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
})
|
||||
|
||||
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
|
||||
@ -262,18 +196,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader])
|
||||
})
|
||||
|
||||
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
|
||||
@ -284,27 +207,8 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Empty(t, reqClone.Header[customHeader])
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Request: req,
|
||||
Body: io.NopCloser(bytes.NewBufferString("")),
|
||||
}, nil
|
||||
})
|
||||
|
@ -4,9 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
@ -16,8 +14,8 @@ import (
|
||||
const cookieHeaderName = "Cookie"
|
||||
|
||||
// NewCookiesMiddleware creates a new plugins.ClientMiddleware that will
|
||||
// forward incoming HTTP request Cookies to outgoing plugins.Client and
|
||||
// HTTP requests if the datasource has enabled forwarding of cookies (keepCookies).
|
||||
// forward incoming HTTP request Cookies to outgoing plugins.Client requests
|
||||
// if the datasource has enabled forwarding of cookies (keepCookies).
|
||||
func NewCookiesMiddleware(skipCookiesNames []string) plugins.ClientMiddleware {
|
||||
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
|
||||
return &CookiesMiddleware{
|
||||
@ -32,17 +30,17 @@ type CookiesMiddleware struct {
|
||||
skipCookiesNames []string
|
||||
}
|
||||
|
||||
func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
|
||||
func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) error {
|
||||
reqCtx := contexthandler.FromContext(ctx)
|
||||
// if request not for a datasource or no HTTP request context skip middleware
|
||||
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
|
||||
return ctx, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
settings := pCtx.DataSourceInstanceSettings
|
||||
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
return err
|
||||
}
|
||||
|
||||
ds := &datasources.DataSource{
|
||||
@ -54,20 +52,29 @@ func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.Plugi
|
||||
|
||||
proxyutil.ClearCookieHeader(reqCtx.Req, ds.AllowedCookies(), m.skipCookiesNames)
|
||||
|
||||
if cookieStr := reqCtx.Req.Header.Get(cookieHeaderName); cookieStr != "" {
|
||||
switch t := req.(type) {
|
||||
case *backend.QueryDataRequest:
|
||||
cookieStr := reqCtx.Req.Header.Get(cookieHeaderName)
|
||||
switch t := req.(type) {
|
||||
case *backend.QueryDataRequest:
|
||||
if cookieStr == "" {
|
||||
delete(t.Headers, cookieHeaderName)
|
||||
} else {
|
||||
t.Headers[cookieHeaderName] = cookieStr
|
||||
case *backend.CheckHealthRequest:
|
||||
}
|
||||
case *backend.CheckHealthRequest:
|
||||
if cookieStr == "" {
|
||||
delete(t.Headers, cookieHeaderName)
|
||||
} else {
|
||||
t.Headers[cookieHeaderName] = cookieStr
|
||||
case *backend.CallResourceRequest:
|
||||
}
|
||||
case *backend.CallResourceRequest:
|
||||
if cookieStr == "" {
|
||||
delete(t.Headers, cookieHeaderName)
|
||||
} else {
|
||||
t.Headers[cookieHeaderName] = []string{cookieStr}
|
||||
}
|
||||
}
|
||||
|
||||
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(reqCtx.Req.Cookies(), ds.AllowedCookies(), m.skipCookiesNames))
|
||||
|
||||
return ctx, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
@ -75,12 +82,12 @@ func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDat
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.next.QueryData(newCtx, req)
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
|
||||
@ -88,12 +95,12 @@ func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallR
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.next.CallResource(newCtx, req, sender)
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
|
||||
@ -101,12 +108,12 @@ func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckH
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
err := m.applyCookies(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.next.CheckHealth(newCtx, req)
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
func (m *CookiesMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
|
||||
|
@ -6,8 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -54,39 +52,19 @@ func TestCookiesMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
})
|
||||
|
||||
t.Run("Should not forward cookies when calling CallResource", func(t *testing.T) {
|
||||
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
|
||||
pReq := &backend.CallResourceRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string][]string{otherHeader: {"test"}},
|
||||
}, nopCallResourceSender)
|
||||
}
|
||||
pReq.Headers[backend.CookiesHeaderName] = []string{req.Header.Get(backend.CookiesHeaderName)}
|
||||
err = cdt.Decorator.CallResource(req.Context(), pReq, nopCallResourceSender)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
})
|
||||
|
||||
t.Run("Should not forward cookies when calling CheckHealth", func(t *testing.T) {
|
||||
@ -98,17 +76,6 @@ func TestCookiesMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 1)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
})
|
||||
})
|
||||
|
||||
@ -157,18 +124,6 @@ func TestCookiesMiddleware(t *testing.T) {
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 2)
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
require.EqualValues(t, "cookie2=", cdt.QueryDataReq.Headers[cookieHeaderName])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 2)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
|
||||
})
|
||||
|
||||
t.Run("Should forward cookies when calling CallResource", func(t *testing.T) {
|
||||
@ -182,18 +137,6 @@ func TestCookiesMiddleware(t *testing.T) {
|
||||
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
|
||||
require.Len(t, cdt.CallResourceReq.Headers[cookieHeaderName], 1)
|
||||
require.EqualValues(t, "cookie2=", cdt.CallResourceReq.Headers[cookieHeaderName][0])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 2)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
|
||||
})
|
||||
|
||||
t.Run("Should forward cookies when calling CheckHealth", func(t *testing.T) {
|
||||
@ -206,18 +149,6 @@ func TestCookiesMiddleware(t *testing.T) {
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 2)
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
require.EqualValues(t, "cookie2=", cdt.CheckHealthReq.Headers[cookieHeaderName])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 2)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -0,0 +1,108 @@
|
||||
package clientmiddleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
)
|
||||
|
||||
const forwardPluginRequestHTTPHeaders = "forward-plugin-request-http-headers"
|
||||
|
||||
// NewHTTPClientMiddleware creates a new plugins.ClientMiddleware
|
||||
// that will forward plugin request headers as outgoing HTTP headers.
|
||||
func NewHTTPClientMiddleware() plugins.ClientMiddleware {
|
||||
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
|
||||
return &HTTPClientMiddleware{
|
||||
next: next,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type HTTPClientMiddleware struct {
|
||||
next plugins.Client
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) applyHeaders(ctx context.Context, pReq interface{}) context.Context {
|
||||
if pReq == nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
mw := httpclient.NamedMiddlewareFunc(forwardPluginRequestHTTPHeaders, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
switch t := pReq.(type) {
|
||||
case *backend.QueryDataRequest:
|
||||
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
|
||||
req.Header.Set(ngalertmodels.FromAlertHeaderName, val)
|
||||
}
|
||||
case *backend.CallResourceRequest:
|
||||
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
|
||||
req.Header.Set(ngalertmodels.FromAlertHeaderName, val[0])
|
||||
}
|
||||
case *backend.CheckHealthRequest:
|
||||
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
|
||||
req.Header.Set(ngalertmodels.FromAlertHeaderName, val)
|
||||
}
|
||||
}
|
||||
|
||||
if h, ok := pReq.(backend.ForwardHTTPHeaders); ok {
|
||||
for k, v := range h.GetHTTPHeaders() {
|
||||
req.Header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
})
|
||||
|
||||
return httpclient.WithContextualMiddleware(ctx, mw)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
if req == nil {
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.applyHeaders(ctx, req)
|
||||
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
|
||||
if req == nil {
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
ctx = m.applyHeaders(ctx, req)
|
||||
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
|
||||
if req == nil {
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.applyHeaders(ctx, req)
|
||||
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
|
||||
return m.next.CollectMetrics(ctx, req)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
|
||||
return m.next.SubscribeStream(ctx, req)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
|
||||
return m.next.PublishStream(ctx, req)
|
||||
}
|
||||
|
||||
func (m *HTTPClientMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
|
||||
return m.next.RunStream(ctx, req, sender)
|
||||
}
|
@ -0,0 +1,289 @@
|
||||
package clientmiddleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util/proxyutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPClientMiddleware(t *testing.T) {
|
||||
const otherHeader = "test"
|
||||
|
||||
t.Run("When no http headers in plugin request", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("And requests are for a datasource", func(t *testing.T) {
|
||||
cdt := clienttest.NewClientDecoratorTest(t,
|
||||
clienttest.WithReqContext(req, &user.SignedInUser{}),
|
||||
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
|
||||
)
|
||||
|
||||
pluginCtx := backend.PluginContext{
|
||||
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
|
||||
}
|
||||
|
||||
t.Run("Should not forward headers when calling QueryData", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string]string{otherHeader: "val"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward headers when calling CallResource", func(t *testing.T) {
|
||||
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string][]string{otherHeader: {"val"}},
|
||||
}, nopCallResourceSender)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string]string{otherHeader: "val"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("And requests are for an app", func(t *testing.T) {
|
||||
cdt := clienttest.NewClientDecoratorTest(t,
|
||||
clienttest.WithReqContext(req, &user.SignedInUser{}),
|
||||
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
|
||||
)
|
||||
|
||||
pluginCtx := backend.PluginContext{
|
||||
AppInstanceSettings: &backend.AppInstanceSettings{},
|
||||
}
|
||||
|
||||
t.Run("Should not forward headers when calling QueryData", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string]string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 0)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward headers when calling CallResource", func(t *testing.T) {
|
||||
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string][]string{},
|
||||
}, nopCallResourceSender)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 0)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: map[string]string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 0)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("When HTTP headers in plugin request", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
headers := map[string]string{
|
||||
ngalertmodels.FromAlertHeaderName: "true",
|
||||
backend.OAuthIdentityTokenHeaderName: "bearer token",
|
||||
backend.OAuthIdentityIDTokenHeaderName: "id-token",
|
||||
"http_" + proxyutil.UserHeaderName: "uname",
|
||||
backend.CookiesHeaderName: "cookie1=; cookie2=; cookie3=",
|
||||
otherHeader: "val",
|
||||
}
|
||||
|
||||
crHeaders := map[string][]string{}
|
||||
for k, v := range headers {
|
||||
crHeaders[k] = []string{v}
|
||||
}
|
||||
|
||||
t.Run("And requests are for a datasource", func(t *testing.T) {
|
||||
cdt := clienttest.NewClientDecoratorTest(t,
|
||||
clienttest.WithReqContext(req, &user.SignedInUser{}),
|
||||
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
|
||||
)
|
||||
|
||||
pluginCtx := backend.PluginContext{
|
||||
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
|
||||
}
|
||||
|
||||
t.Run("Should forward headers when calling QueryData", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: headers,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 6)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 5)
|
||||
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
|
||||
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
|
||||
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
|
||||
require.Len(t, reqClone.Cookies(), 3)
|
||||
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
|
||||
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
|
||||
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
|
||||
})
|
||||
|
||||
t.Run("Should forward headers when calling CallResource", func(t *testing.T) {
|
||||
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: crHeaders,
|
||||
}, nopCallResourceSender)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 6)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 5)
|
||||
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
|
||||
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
|
||||
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
|
||||
require.Len(t, reqClone.Cookies(), 3)
|
||||
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
|
||||
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
|
||||
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
|
||||
})
|
||||
|
||||
t.Run("Should forward headers when calling CheckHealth", func(t *testing.T) {
|
||||
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
|
||||
PluginContext: pluginCtx,
|
||||
Headers: headers,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 6)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 5)
|
||||
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
|
||||
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
|
||||
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
|
||||
require.Len(t, reqClone.Cookies(), 3)
|
||||
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
|
||||
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
|
||||
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Request: req,
|
||||
Body: io.NopCloser(bytes.NewBufferString("")),
|
||||
}, nil
|
||||
})
|
@ -3,12 +3,9 @@ package clientmiddleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
@ -16,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
// NewOAuthTokenMiddleware creates a new plugins.ClientMiddleware that will
|
||||
// set OAuth token headers on outgoing plugins.Client and HTTP requests if
|
||||
// the datasource has enabled Forward OAuth Identity (oauthPassThru).
|
||||
// set OAuth token headers on outgoing plugins.Client requests if the
|
||||
// datasource has enabled Forward OAuth Identity (oauthPassThru).
|
||||
func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plugins.ClientMiddleware {
|
||||
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
|
||||
return &OAuthTokenMiddleware{
|
||||
@ -37,17 +34,17 @@ type OAuthTokenMiddleware struct {
|
||||
next plugins.Client
|
||||
}
|
||||
|
||||
func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
|
||||
func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) error {
|
||||
reqCtx := contexthandler.FromContext(ctx)
|
||||
// if request not for a datasource or no HTTP request context skip middleware
|
||||
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
|
||||
return ctx, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
settings := pCtx.DataSourceInstanceSettings
|
||||
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
return err
|
||||
}
|
||||
|
||||
ds := &datasources.DataSource{
|
||||
@ -84,19 +81,10 @@ func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.Plug
|
||||
t.Headers[idTokenHeaderName] = []string{idTokenHeader}
|
||||
}
|
||||
}
|
||||
|
||||
httpHeaders := http.Header{}
|
||||
httpHeaders.Set(tokenHeaderName, authorizationHeader)
|
||||
|
||||
if idTokenHeader != "" {
|
||||
httpHeaders.Set(idTokenHeaderName, idTokenHeader)
|
||||
}
|
||||
|
||||
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.SetHeadersMiddleware(httpHeaders))
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
@ -104,12 +92,12 @@ func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.Query
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
|
||||
err := m.applyToken(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.next.QueryData(newCtx, req)
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
|
||||
@ -117,12 +105,12 @@ func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.Ca
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
|
||||
err := m.applyToken(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.next.CallResource(newCtx, req, sender)
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
|
||||
@ -130,12 +118,12 @@ func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.Che
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
|
||||
err := m.applyToken(ctx, req.PluginContext, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.next.CheckHealth(newCtx, req)
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
func (m *OAuthTokenMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
|
||||
|
@ -6,8 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
@ -49,9 +47,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward OAuth Identity when calling CallResource", func(t *testing.T) {
|
||||
@ -63,9 +58,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
|
||||
t.Run("Should not forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
|
||||
@ -77,9 +69,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 0)
|
||||
})
|
||||
})
|
||||
|
||||
@ -125,19 +114,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
|
||||
require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName])
|
||||
require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 3)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
|
||||
})
|
||||
|
||||
t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) {
|
||||
@ -153,19 +129,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0])
|
||||
require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1)
|
||||
require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 3)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
|
||||
})
|
||||
|
||||
t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
|
||||
@ -179,19 +142,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
|
||||
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
|
||||
require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName])
|
||||
require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName])
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
|
||||
reqClone := req.Clone(req.Context())
|
||||
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, res.Body.Close())
|
||||
require.Len(t, reqClone.Header, 3)
|
||||
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
|
||||
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
|
||||
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -2,19 +2,15 @@ package clientmiddleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/util/proxyutil"
|
||||
)
|
||||
|
||||
// NewUserHeaderMiddleware creates a new plugins.ClientMiddleware that will
|
||||
// populate the X-Grafana-User header on outgoing plugins.Client and HTTP
|
||||
// requests.
|
||||
// populate the X-Grafana-User header on outgoing plugins.Client requests.
|
||||
func NewUserHeaderMiddleware() plugins.ClientMiddleware {
|
||||
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
|
||||
return &UserHeaderMiddleware{
|
||||
@ -27,33 +23,17 @@ type UserHeaderMiddleware struct {
|
||||
next plugins.Client
|
||||
}
|
||||
|
||||
func (m *UserHeaderMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, h backend.ForwardHTTPHeaders) context.Context {
|
||||
func (m *UserHeaderMiddleware) applyUserHeader(ctx context.Context, h backend.ForwardHTTPHeaders) {
|
||||
reqCtx := contexthandler.FromContext(ctx)
|
||||
// if no HTTP request context skip middleware
|
||||
if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
|
||||
return ctx
|
||||
return
|
||||
}
|
||||
|
||||
h.DeleteHTTPHeader(proxyutil.UserHeaderName)
|
||||
if !reqCtx.IsAnonymous {
|
||||
h.SetHTTPHeader(proxyutil.UserHeaderName, reqCtx.Login)
|
||||
}
|
||||
|
||||
middlewares := []sdkhttpclient.Middleware{}
|
||||
|
||||
if !reqCtx.IsAnonymous {
|
||||
httpHeaders := http.Header{
|
||||
proxyutil.UserHeaderName: []string{reqCtx.Login},
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, httpclientprovider.SetHeadersMiddleware(httpHeaders))
|
||||
} else {
|
||||
middlewares = append(middlewares, httpclientprovider.DeleteHeadersMiddleware(proxyutil.UserHeaderName))
|
||||
}
|
||||
|
||||
ctx = sdkhttpclient.WithContextualMiddleware(ctx, middlewares...)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
@ -61,7 +41,7 @@ func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.Query
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.applyToken(ctx, req.PluginContext, req)
|
||||
m.applyUserHeader(ctx, req)
|
||||
|
||||
return m.next.QueryData(ctx, req)
|
||||
}
|
||||
@ -71,7 +51,7 @@ func (m *UserHeaderMiddleware) CallResource(ctx context.Context, req *backend.Ca
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
|
||||
ctx = m.applyToken(ctx, req.PluginContext, req)
|
||||
m.applyUserHeader(ctx, req)
|
||||
|
||||
return m.next.CallResource(ctx, req, sender)
|
||||
}
|
||||
@ -81,7 +61,7 @@ func (m *UserHeaderMiddleware) CheckHealth(ctx context.Context, req *backend.Che
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
||||
ctx = m.applyToken(ctx, req.PluginContext, req)
|
||||
m.applyUserHeader(ctx, req)
|
||||
|
||||
return m.next.CheckHealth(ctx, req)
|
||||
}
|
||||
|
@ -5,8 +5,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
|
||||
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util/proxyutil"
|
||||
@ -39,10 +37,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Empty(t, cdt.QueryDataReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should not forward user header when calling CallResource", func(t *testing.T) {
|
||||
@ -53,10 +47,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Empty(t, cdt.CallResourceReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) {
|
||||
@ -67,10 +57,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Empty(t, cdt.CheckHealthReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
})
|
||||
|
||||
@ -95,10 +81,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Empty(t, cdt.QueryDataReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should not forward user header when calling CallResource", func(t *testing.T) {
|
||||
@ -109,10 +91,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Empty(t, cdt.CallResourceReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) {
|
||||
@ -123,10 +101,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Empty(t, cdt.CheckHealthReq.Headers)
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -156,10 +130,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should forward user header when calling CallResource", func(t *testing.T) {
|
||||
@ -171,10 +141,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) {
|
||||
@ -186,10 +152,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
})
|
||||
|
||||
@ -214,10 +176,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.QueryDataReq)
|
||||
require.Len(t, cdt.QueryDataReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should forward user header when calling CallResource", func(t *testing.T) {
|
||||
@ -229,10 +187,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CallResourceReq)
|
||||
require.Len(t, cdt.CallResourceReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) {
|
||||
@ -244,10 +198,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
|
||||
require.NotNil(t, cdt.CheckHealthReq)
|
||||
require.Len(t, cdt.CheckHealthReq.Headers, 1)
|
||||
require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName))
|
||||
|
||||
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
|
||||
require.Len(t, middlewares, 1)
|
||||
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -81,5 +81,7 @@ func CreateMiddlewares(cfg *setting.Cfg, oAuthTokenService oauthtoken.OAuthToken
|
||||
middlewares = append(middlewares, clientmiddleware.NewUserHeaderMiddleware())
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, clientmiddleware.NewHTTPClientMiddleware())
|
||||
|
||||
return middlewares
|
||||
}
|
||||
|
@ -26,92 +26,226 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const loginCookieName = "grafana_session"
|
||||
|
||||
func TestIntegrationBackendPlugins(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
regularQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
|
||||
t.Helper()
|
||||
|
||||
return metricRequestWithQueries(t, fmt.Sprintf(`{
|
||||
"datasource": {
|
||||
"uid": "%s"
|
||||
}
|
||||
}`, tsCtx.uid))
|
||||
oauthToken := &oauth2.Token{
|
||||
TokenType: "bearer",
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
Expiry: time.Now().UTC().Add(24 * time.Hour),
|
||||
}
|
||||
oauthToken = oauthToken.WithExtra(map[string]interface{}{"id_token": "id-token"})
|
||||
|
||||
expressionQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
|
||||
t.Helper()
|
||||
newTestScenario(t, "Datasource with no custom HTTP settings",
|
||||
options(
|
||||
withIncomingRequest(func(req *http.Request) {
|
||||
req.Header.Set("X-Custom", "custom")
|
||||
req.AddCookie(&http.Cookie{Name: "cookie1"})
|
||||
req.AddCookie(&http.Cookie{Name: "cookie2"})
|
||||
req.AddCookie(&http.Cookie{Name: "cookie3"})
|
||||
req.AddCookie(&http.Cookie{Name: loginCookieName})
|
||||
}),
|
||||
),
|
||||
func(t *testing.T, tsCtx *testScenarioContext) {
|
||||
verify := func(h backend.ForwardHTTPHeaders) {
|
||||
require.NotNil(t, h)
|
||||
require.Empty(t, h.GetHTTPHeader(backend.CookiesHeaderName))
|
||||
require.Empty(t, h.GetHTTPHeader("Authorization"))
|
||||
|
||||
return metricRequestWithQueries(t, fmt.Sprintf(`{
|
||||
"refId": "A",
|
||||
"datasource": {
|
||||
"uid": "%s",
|
||||
"type": "%s"
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.NotEmpty(t, tsCtx.outgoingRequest.Header)
|
||||
}
|
||||
}`, tsCtx.uid, tsCtx.testPluginID), `{
|
||||
"refId": "B",
|
||||
"datasource": {
|
||||
"type": "__expr__",
|
||||
"uid": "__expr__",
|
||||
"name": "Expression"
|
||||
},
|
||||
"type": "math",
|
||||
"expression": "$A - 50"
|
||||
}`)
|
||||
}
|
||||
|
||||
newTestScenario(t, "When oauth token not available", func(t *testing.T, tsCtx *testScenarioContext) {
|
||||
tsCtx.testEnv.OAuthTokenService.Token = nil
|
||||
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
|
||||
verify(pReq)
|
||||
})
|
||||
|
||||
tsCtx.runCheckHealthTest(t)
|
||||
tsCtx.runCallResourceTest(t)
|
||||
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
|
||||
verify(pReq)
|
||||
require.Equal(t, "custom", pReq.GetHTTPHeader("X-Custom"))
|
||||
require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom"))
|
||||
})
|
||||
|
||||
t.Run("regular query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
|
||||
verifyQueryData := func(pReq *backend.QueryDataRequest) {
|
||||
verify(pReq)
|
||||
}
|
||||
|
||||
t.Run("regular query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
|
||||
|
||||
t.Run("expression query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("expression query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
|
||||
})
|
||||
})
|
||||
newTestScenario(t, "Datasource with most HTTP settings set except oauthPassThru and oauth token available",
|
||||
options(
|
||||
withIncomingRequest(func(req *http.Request) {
|
||||
req.AddCookie(&http.Cookie{Name: "cookie1"})
|
||||
req.AddCookie(&http.Cookie{Name: "cookie2"})
|
||||
req.AddCookie(&http.Cookie{Name: "cookie3"})
|
||||
req.AddCookie(&http.Cookie{Name: loginCookieName})
|
||||
}),
|
||||
withOAuthToken(oauthToken),
|
||||
withDsBasicAuth("basicAuthUser", "basicAuthPassword"),
|
||||
withDsCustomHeader(map[string]string{"X-CUSTOM-HEADER": "custom-header-value"}),
|
||||
withDsCookieForwarding([]string{"cookie1", "cookie3", loginCookieName}),
|
||||
),
|
||||
func(t *testing.T, tsCtx *testScenarioContext) {
|
||||
verify := func(h backend.ForwardHTTPHeaders) {
|
||||
require.NotNil(t, h)
|
||||
require.Equal(t, "cookie1=; cookie3=", h.GetHTTPHeader(backend.CookiesHeaderName))
|
||||
require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName))
|
||||
|
||||
newTestScenario(t, "When oauth token available", func(t *testing.T, tsCtx *testScenarioContext) {
|
||||
token := &oauth2.Token{
|
||||
TokenType: "bearer",
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
Expiry: time.Now().UTC().Add(24 * time.Hour),
|
||||
}
|
||||
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
|
||||
tsCtx.testEnv.OAuthTokenService.Token = token
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get(backend.CookiesHeaderName))
|
||||
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
|
||||
|
||||
tsCtx.runCheckHealthTest(t)
|
||||
tsCtx.runCallResourceTest(t)
|
||||
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "basicAuthUser", username)
|
||||
require.Equal(t, "basicAuthPassword", pwd)
|
||||
}
|
||||
|
||||
t.Run("regular query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
|
||||
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
|
||||
verify(pReq)
|
||||
})
|
||||
|
||||
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
|
||||
verify(pReq)
|
||||
})
|
||||
|
||||
verifyQueryData := func(pReq *backend.QueryDataRequest) {
|
||||
verify(pReq)
|
||||
}
|
||||
|
||||
t.Run("regular query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
|
||||
|
||||
t.Run("expression query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("expression query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
|
||||
newTestScenario(t, "Datasource with oauthPassThru and basic auth configured and oauth token available",
|
||||
options(
|
||||
withOAuthToken(oauthToken),
|
||||
withDsOAuthForwarding(),
|
||||
withDsBasicAuth("basicAuthUser", "basicAuthPassword"),
|
||||
),
|
||||
func(t *testing.T, tsCtx *testScenarioContext) {
|
||||
verify := func(h backend.ForwardHTTPHeaders) {
|
||||
require.NotNil(t, h)
|
||||
|
||||
expectedAuthHeader := fmt.Sprintf("Bearer %s", oauthToken.AccessToken)
|
||||
expectedTokenHeader := oauthToken.Extra("id_token").(string)
|
||||
|
||||
require.Equal(t, expectedAuthHeader, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Equal(t, expectedTokenHeader, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName))
|
||||
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityTokenHeaderName))
|
||||
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
|
||||
}
|
||||
|
||||
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
|
||||
verify(pReq)
|
||||
})
|
||||
|
||||
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
|
||||
verify(pReq)
|
||||
})
|
||||
|
||||
verifyQueryData := func(pReq *backend.QueryDataRequest) {
|
||||
verify(pReq)
|
||||
}
|
||||
|
||||
t.Run("regular query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
|
||||
|
||||
t.Run("expression query", func(t *testing.T) {
|
||||
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type testScenarioContext struct {
|
||||
testPluginID string
|
||||
uid string
|
||||
grafanaListeningAddr string
|
||||
testEnv *server.TestEnv
|
||||
outgoingServer *httptest.Server
|
||||
outgoingRequest *http.Request
|
||||
backendTestPlugin *testPlugin
|
||||
rt http.RoundTripper
|
||||
testPluginID string
|
||||
uid string
|
||||
grafanaListeningAddr string
|
||||
testEnv *server.TestEnv
|
||||
outgoingServer *httptest.Server
|
||||
outgoingRequest *http.Request
|
||||
backendTestPlugin *testPlugin
|
||||
rt http.RoundTripper
|
||||
modifyIncomingRequest func(req *http.Request)
|
||||
}
|
||||
|
||||
func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx *testScenarioContext)) {
|
||||
type testScenarioInput struct {
|
||||
ds *datasources.AddDataSourceCommand
|
||||
token *oauth2.Token
|
||||
modifyIncomingRequest func(req *http.Request)
|
||||
}
|
||||
|
||||
type testScenarioOption func(*testScenarioInput)
|
||||
|
||||
func options(opts ...testScenarioOption) []testScenarioOption {
|
||||
return opts
|
||||
}
|
||||
|
||||
func withIncomingRequest(cb func(req *http.Request)) testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
in.modifyIncomingRequest = cb
|
||||
}
|
||||
}
|
||||
|
||||
func withOAuthToken(token *oauth2.Token) testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
in.token = token
|
||||
}
|
||||
}
|
||||
|
||||
func withDsBasicAuth(username, password string) testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
in.ds.BasicAuth = true
|
||||
in.ds.BasicAuthUser = username
|
||||
in.ds.SecureJsonData["basicAuthPassword"] = password
|
||||
}
|
||||
}
|
||||
|
||||
func withDsCustomHeader(headers map[string]string) testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
index := 1
|
||||
for k, v := range headers {
|
||||
in.ds.JsonData.Set(fmt.Sprintf("httpHeaderName%d", index), k)
|
||||
in.ds.SecureJsonData[fmt.Sprintf("httpHeaderValue%d", index)] = v
|
||||
index++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withDsOAuthForwarding() testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
in.ds.JsonData.Set("oauthPassThru", true)
|
||||
}
|
||||
}
|
||||
|
||||
func withDsCookieForwarding(names []string) testScenarioOption {
|
||||
return func(in *testScenarioInput) {
|
||||
in.ds.JsonData.Set("keepCookies", names)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestScenario(t *testing.T, name string, opts []testScenarioOption, callback func(t *testing.T, ctx *testScenarioContext)) {
|
||||
tsCtx := testScenarioContext{
|
||||
testPluginID: "test-plugin",
|
||||
}
|
||||
@ -123,6 +257,7 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
|
||||
|
||||
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
|
||||
tsCtx.grafanaListeningAddr = grafanaListeningAddr
|
||||
testEnv.SQLStore.Cfg.LoginCookieName = loginCookieName
|
||||
tsCtx.testEnv = testEnv
|
||||
ctx := context.Background()
|
||||
|
||||
@ -143,29 +278,30 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
|
||||
err := testEnv.PluginRegistry.Add(ctx, testPlugin)
|
||||
require.NoError(t, err)
|
||||
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"httpHeaderName1": "X-CUSTOM-HEADER",
|
||||
"oauthPassThru": true,
|
||||
"keepCookies": []string{"cookie1", "cookie3", "grafana_session"},
|
||||
})
|
||||
secureJSONData := map[string]string{
|
||||
"basicAuthPassword": "basicAuthPassword",
|
||||
"httpHeaderValue1": "custom-header-value",
|
||||
}
|
||||
jsonData := simplejson.New()
|
||||
secureJSONData := map[string]string{}
|
||||
|
||||
tsCtx.uid = "test-plugin"
|
||||
err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, &datasources.AddDataSourceCommand{
|
||||
cmd := &datasources.AddDataSourceCommand{
|
||||
OrgId: 1,
|
||||
Access: datasources.DS_ACCESS_PROXY,
|
||||
Name: "TestPlugin",
|
||||
Type: tsCtx.testPluginID,
|
||||
Uid: tsCtx.uid,
|
||||
Url: tsCtx.outgoingServer.URL,
|
||||
BasicAuth: true,
|
||||
BasicAuthUser: "basicAuthUser",
|
||||
JsonData: jsonData,
|
||||
SecureJsonData: secureJSONData,
|
||||
})
|
||||
}
|
||||
|
||||
in := &testScenarioInput{ds: cmd}
|
||||
for _, opt := range opts {
|
||||
opt(in)
|
||||
}
|
||||
|
||||
tsCtx.modifyIncomingRequest = in.modifyIncomingRequest
|
||||
tsCtx.testEnv.OAuthTokenService.Token = in.token
|
||||
|
||||
err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
getDataSourceQuery := &datasources.GetDataSourceQuery{
|
||||
@ -185,17 +321,42 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
|
||||
})
|
||||
}
|
||||
|
||||
func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest) {
|
||||
t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) {
|
||||
var received *struct {
|
||||
ctx context.Context
|
||||
req *backend.QueryDataRequest
|
||||
func createRegularQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
|
||||
t.Helper()
|
||||
|
||||
return metricRequestWithQueries(t, fmt.Sprintf(`{
|
||||
"datasource": {
|
||||
"uid": "%s"
|
||||
}
|
||||
}`, tsCtx.uid))
|
||||
}
|
||||
|
||||
func createExpressionQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
|
||||
t.Helper()
|
||||
|
||||
return metricRequestWithQueries(t, fmt.Sprintf(`{
|
||||
"refId": "A",
|
||||
"datasource": {
|
||||
"uid": "%s",
|
||||
"type": "%s"
|
||||
}
|
||||
}`, tsCtx.uid, tsCtx.testPluginID), `{
|
||||
"refId": "B",
|
||||
"datasource": {
|
||||
"type": "__expr__",
|
||||
"uid": "__expr__",
|
||||
"name": "Expression"
|
||||
},
|
||||
"type": "math",
|
||||
"expression": "$A - 50"
|
||||
}`)
|
||||
}
|
||||
|
||||
func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest, callback func(req *backend.QueryDataRequest)) {
|
||||
t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) {
|
||||
var received *backend.QueryDataRequest
|
||||
tsCtx.backendTestPlugin.QueryDataHandler = backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
received = &struct {
|
||||
ctx context.Context
|
||||
req *backend.QueryDataRequest
|
||||
}{ctx, req}
|
||||
received = req
|
||||
|
||||
c := http.Client{
|
||||
Transport: tsCtx.rt,
|
||||
@ -227,18 +388,11 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, u, buf1)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie1",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie2",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie3",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "grafana_session",
|
||||
})
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
|
||||
|
||||
if tsCtx.modifyIncomingRequest != nil {
|
||||
tsCtx.modifyIncomingRequest(req)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
@ -253,51 +407,18 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// backend query data request
|
||||
require.NotNil(t, received)
|
||||
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
|
||||
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
|
||||
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
|
||||
|
||||
token := tsCtx.testEnv.OAuthTokenService.Token
|
||||
|
||||
var expectedAuthHeader string
|
||||
var expectedTokenHeader string
|
||||
|
||||
if token != nil {
|
||||
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
|
||||
expectedTokenHeader = token.Extra("id_token").(string)
|
||||
|
||||
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
|
||||
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
|
||||
}
|
||||
|
||||
// outgoing HTTP request
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
|
||||
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
|
||||
|
||||
if token == nil {
|
||||
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "basicAuthUser", username)
|
||||
require.Equal(t, "basicAuthPassword", pwd)
|
||||
} else {
|
||||
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
|
||||
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
|
||||
}
|
||||
callback(received)
|
||||
})
|
||||
}
|
||||
|
||||
func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
|
||||
func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T, callback func(req *backend.CheckHealthRequest)) {
|
||||
t.Run("When calling /api/datasources/uid/:uid/health should set expected headers on outgoing CheckHealth and HTTP request", func(t *testing.T) {
|
||||
var received *struct {
|
||||
ctx context.Context
|
||||
req *backend.CheckHealthRequest
|
||||
}
|
||||
var received *backend.CheckHealthRequest
|
||||
tsCtx.backendTestPlugin.CheckHealthHandler = backend.CheckHealthHandlerFunc(func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
|
||||
received = &struct {
|
||||
ctx context.Context
|
||||
req *backend.CheckHealthRequest
|
||||
}{ctx, req}
|
||||
received = req
|
||||
|
||||
c := http.Client{
|
||||
Transport: tsCtx.rt,
|
||||
@ -328,18 +449,11 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie1",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie2",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie3",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "grafana_session",
|
||||
})
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
|
||||
|
||||
if tsCtx.modifyIncomingRequest != nil {
|
||||
tsCtx.modifyIncomingRequest(req)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
@ -354,51 +468,18 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// backend query data request
|
||||
require.NotNil(t, received)
|
||||
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
|
||||
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
|
||||
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
|
||||
|
||||
token := tsCtx.testEnv.OAuthTokenService.Token
|
||||
|
||||
var expectedAuthHeader string
|
||||
var expectedTokenHeader string
|
||||
|
||||
if token != nil {
|
||||
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
|
||||
expectedTokenHeader = token.Extra("id_token").(string)
|
||||
|
||||
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
|
||||
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
|
||||
}
|
||||
|
||||
// outgoing HTTP request
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
|
||||
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
|
||||
|
||||
if token == nil {
|
||||
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "basicAuthUser", username)
|
||||
require.Equal(t, "basicAuthPassword", pwd)
|
||||
} else {
|
||||
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
|
||||
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
|
||||
}
|
||||
callback(received)
|
||||
})
|
||||
}
|
||||
|
||||
func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
|
||||
func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T, callback func(req *backend.CallResourceRequest)) {
|
||||
t.Run("When calling /api/datasources/uid/:uid/resources should set expected headers on outgoing CallResource and HTTP request", func(t *testing.T) {
|
||||
var received *struct {
|
||||
ctx context.Context
|
||||
req *backend.CallResourceRequest
|
||||
}
|
||||
var received *backend.CallResourceRequest
|
||||
tsCtx.backendTestPlugin.CallResourceHandler = backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
|
||||
received = &struct {
|
||||
ctx context.Context
|
||||
req *backend.CallResourceRequest
|
||||
}{ctx, req}
|
||||
received = req
|
||||
|
||||
c := http.Client{
|
||||
Transport: tsCtx.rt,
|
||||
@ -450,19 +531,11 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
|
||||
req.Header.Set("Connection", "X-Some-Conn-Header")
|
||||
req.Header.Set("X-Some-Conn-Header", "should be deleted")
|
||||
req.Header.Set("Proxy-Connection", "should be deleted")
|
||||
req.Header.Set("X-Custom", "custom")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie1",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie2",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "cookie3",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "grafana_session",
|
||||
})
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
|
||||
|
||||
if tsCtx.modifyIncomingRequest != nil {
|
||||
tsCtx.modifyIncomingRequest(req)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
@ -484,45 +557,18 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
|
||||
require.Empty(t, resp.Header.Get("Set-Cookie"))
|
||||
require.Equal(t, "should not be deleted", resp.Header.Get("X-Custom"))
|
||||
|
||||
// backend query data request
|
||||
require.NotNil(t, received)
|
||||
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"][0])
|
||||
require.Empty(t, received.req.Headers["Connection"])
|
||||
require.Empty(t, received.req.Headers["X-Some-Conn-Header"])
|
||||
require.Empty(t, received.req.Headers["Proxy-Connection"])
|
||||
require.Equal(t, "custom", received.req.Headers["X-Custom"][0])
|
||||
require.Empty(t, received.Headers["Connection"])
|
||||
require.Empty(t, received.Headers["X-Some-Conn-Header"])
|
||||
require.Empty(t, received.Headers["Proxy-Connection"])
|
||||
|
||||
token := tsCtx.testEnv.OAuthTokenService.Token
|
||||
|
||||
var expectedAuthHeader string
|
||||
var expectedTokenHeader string
|
||||
|
||||
if token != nil {
|
||||
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
|
||||
expectedTokenHeader = token.Extra("id_token").(string)
|
||||
|
||||
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"][0])
|
||||
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"][0])
|
||||
}
|
||||
|
||||
// outgoing HTTP request
|
||||
require.NotNil(t, tsCtx.outgoingRequest)
|
||||
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
|
||||
require.Empty(t, tsCtx.outgoingRequest.Header.Get("Connection"))
|
||||
require.Empty(t, tsCtx.outgoingRequest.Header.Get("X-Some-Conn-Header"))
|
||||
require.Empty(t, tsCtx.outgoingRequest.Header.Get("Proxy-Connection"))
|
||||
require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom"))
|
||||
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
|
||||
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
|
||||
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
|
||||
|
||||
if token == nil {
|
||||
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "basicAuthUser", username)
|
||||
require.Equal(t, "basicAuthPassword", pwd)
|
||||
} else {
|
||||
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
|
||||
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
|
||||
}
|
||||
callback(received)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/clients"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/models"
|
||||
@ -162,7 +163,7 @@ func (e *cloudWatchExecutor) QueryData(ctx context.Context, req *backend.QueryDa
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, fromAlert := req.Headers["FromAlert"]
|
||||
_, fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName]
|
||||
isLogAlertQuery := fromAlert && model.QueryMode == logsQueryMode
|
||||
|
||||
if isLogAlertQuery {
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/mocks"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/models"
|
||||
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/utils"
|
||||
@ -212,7 +213,7 @@ func Test_executeLogAlertQuery(t *testing.T) {
|
||||
executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures())
|
||||
|
||||
_, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{
|
||||
Headers: map[string]string{"FromAlert": "some value"},
|
||||
Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"},
|
||||
PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}},
|
||||
Queries: []backend.DataQuery{
|
||||
{
|
||||
@ -238,7 +239,7 @@ func Test_executeLogAlertQuery(t *testing.T) {
|
||||
|
||||
executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures())
|
||||
_, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{
|
||||
Headers: map[string]string{"FromAlert": "some value"},
|
||||
Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"},
|
||||
PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}},
|
||||
Queries: []backend.DataQuery{
|
||||
{
|
||||
|
@ -18,10 +18,9 @@ import (
|
||||
)
|
||||
|
||||
type LokiAPI struct {
|
||||
client *http.Client
|
||||
url string
|
||||
log log.Logger
|
||||
headers map[string]string
|
||||
client *http.Client
|
||||
url string
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
type RawLokiResponse struct {
|
||||
@ -29,17 +28,11 @@ type RawLokiResponse struct {
|
||||
Encoding string
|
||||
}
|
||||
|
||||
func newLokiAPI(client *http.Client, url string, log log.Logger, headers map[string]string) *LokiAPI {
|
||||
return &LokiAPI{client: client, url: url, log: log, headers: headers}
|
||||
func newLokiAPI(client *http.Client, url string, log log.Logger) *LokiAPI {
|
||||
return &LokiAPI{client: client, url: url, log: log}
|
||||
}
|
||||
|
||||
func addHeaders(req *http.Request, headers map[string]string) {
|
||||
for name, value := range headers {
|
||||
req.Header.Set(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, headers map[string]string) (*http.Request, error) {
|
||||
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*http.Request, error) {
|
||||
qs := url.Values{}
|
||||
qs.Set("query", query.Expr)
|
||||
|
||||
@ -92,8 +85,6 @@ func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, hea
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addHeaders(req, headers)
|
||||
|
||||
if query.VolumeQuery {
|
||||
req.Header.Set("X-Query-Tags", "Source=logvolhist")
|
||||
}
|
||||
@ -145,7 +136,7 @@ func makeLokiError(body io.ReadCloser) error {
|
||||
}
|
||||
|
||||
func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames, error) {
|
||||
req, err := makeDataRequest(ctx, api.url, query, api.headers)
|
||||
req, err := makeDataRequest(ctx, api.url, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -183,7 +174,7 @@ func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames
|
||||
return res.Frames, nil
|
||||
}
|
||||
|
||||
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string, headers map[string]string) (*http.Request, error) {
|
||||
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string) (*http.Request, error) {
|
||||
lokiUrl, err := url.Parse(lokiDsUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -204,13 +195,11 @@ func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addHeaders(req, headers)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (api *LokiAPI) RawQuery(ctx context.Context, resourcePath string) (RawLokiResponse, error) {
|
||||
req, err := makeRawRequest(ctx, api.url, resourcePath, api.headers)
|
||||
req, err := makeRawRequest(ctx, api.url, resourcePath)
|
||||
if err != nil {
|
||||
return RawLokiResponse{}, err
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func makeMockedAPIWithUrl(url string, statusCode int, contentType string, respon
|
||||
Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback},
|
||||
}
|
||||
|
||||
return newLokiAPI(&client, url, log.New("test"), nil)
|
||||
return newLokiAPI(&client, url, log.New("test"))
|
||||
}
|
||||
|
||||
func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType string, responseBytes []byte, requestCallback mockRequestCallback) *LokiAPI {
|
||||
@ -72,5 +72,5 @@ func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType stri
|
||||
Transport: &mockedCompressedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback},
|
||||
}
|
||||
|
||||
return newLokiAPI(&client, url, log.New("test"), nil)
|
||||
return newLokiAPI(&client, url, log.New("test"))
|
||||
}
|
||||
|
@ -1,201 +0,0 @@
|
||||
package loki
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
)
|
||||
|
||||
type mockedRoundTripperForOauth struct {
|
||||
requestCallback func(req *http.Request)
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (mockedRT *mockedRoundTripperForOauth) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
mockedRT.requestCallback(req)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(mockedRT.body)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockedCallResourceResponseSenderForOauth struct {
|
||||
Response *backend.CallResourceResponse
|
||||
}
|
||||
|
||||
func (s *mockedCallResourceResponseSenderForOauth) Send(resp *backend.CallResourceResponse) error {
|
||||
s.Response = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeMockedDsInfoForOauth(body []byte, requestCallback func(req *http.Request)) datasourceInfo {
|
||||
client := http.Client{
|
||||
Transport: &mockedRoundTripperForOauth{requestCallback: requestCallback, body: body},
|
||||
}
|
||||
|
||||
return datasourceInfo{
|
||||
HTTPClient: &client,
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthForwardIdentity(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
auth bool
|
||||
cookie bool
|
||||
}{
|
||||
{name: "when auth headers exist => add auth headers", auth: true, cookie: false},
|
||||
{name: "when cookie header exists => add cookie header", auth: false, cookie: true},
|
||||
{name: "when cookie&auth headers exist => add cookie&auth headers", auth: true, cookie: true},
|
||||
{name: "when no header exists => do not add headers", auth: false, cookie: false},
|
||||
}
|
||||
|
||||
authName := "Authorization"
|
||||
authValue := "auth"
|
||||
cookieName := "Cookie"
|
||||
cookieValue := "a=1"
|
||||
idTokenName := "X-ID-Token"
|
||||
idTokenValue := "idtoken"
|
||||
|
||||
for _, test := range tt {
|
||||
t.Run("QueryData: "+test.name, func(t *testing.T) {
|
||||
response := []byte(`
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"resultType": "streams",
|
||||
"result": [
|
||||
{
|
||||
"stream": {},
|
||||
"values": [
|
||||
["1", "line1"]
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
`)
|
||||
|
||||
clientUsed := false
|
||||
dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) {
|
||||
clientUsed = true
|
||||
// we need to check for "header does not exist",
|
||||
// and the only way i can find is to get the values
|
||||
// as an array
|
||||
authValues := req.Header.Values(authName)
|
||||
cookieValues := req.Header.Values(cookieName)
|
||||
idTokenValues := req.Header.Values(idTokenName)
|
||||
if test.auth {
|
||||
require.Equal(t, []string{authValue}, authValues)
|
||||
require.Equal(t, []string{idTokenValue}, idTokenValues)
|
||||
} else {
|
||||
require.Len(t, authValues, 0)
|
||||
require.Len(t, idTokenValues, 0)
|
||||
}
|
||||
if test.cookie {
|
||||
require.Equal(t, []string{cookieValue}, cookieValues)
|
||||
} else {
|
||||
require.Len(t, cookieValues, 0)
|
||||
}
|
||||
})
|
||||
|
||||
req := backend.QueryDataRequest{
|
||||
Headers: map[string]string{},
|
||||
Queries: []backend.DataQuery{
|
||||
{
|
||||
RefID: "A",
|
||||
JSON: []byte("{}"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if test.auth {
|
||||
req.Headers[authName] = authValue
|
||||
req.Headers[idTokenName] = idTokenValue
|
||||
}
|
||||
|
||||
if test.cookie {
|
||||
req.Headers[cookieName] = cookieValue
|
||||
}
|
||||
|
||||
tracer := tracing.InitializeTracerForTest()
|
||||
|
||||
data, err := queryData(context.Background(), &req, &dsInfo, tracer)
|
||||
// we do a basic check that the result is OK
|
||||
require.NoError(t, err)
|
||||
require.Len(t, data.Responses, 1)
|
||||
res := data.Responses["A"]
|
||||
require.NoError(t, res.Error)
|
||||
require.Len(t, res.Frames, 1)
|
||||
require.Equal(t, "line1", res.Frames[0].Fields[2].At(0))
|
||||
|
||||
// we need to be sure the client-callback was triggered
|
||||
require.True(t, clientUsed)
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tt {
|
||||
t.Run("CallResource: "+test.name, func(t *testing.T) {
|
||||
response := []byte("mocked resource response")
|
||||
|
||||
clientUsed := false
|
||||
dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) {
|
||||
clientUsed = true
|
||||
authValues := req.Header.Values(authName)
|
||||
cookieValues := req.Header.Values(cookieName)
|
||||
idTokenValues := req.Header.Values(idTokenName)
|
||||
// we need to check for "header does not exist",
|
||||
// and the only way i can find is to get the values
|
||||
// as an array
|
||||
if test.auth {
|
||||
require.Equal(t, []string{authValue}, authValues)
|
||||
require.Equal(t, []string{idTokenValue}, idTokenValues)
|
||||
} else {
|
||||
require.Len(t, authValues, 0)
|
||||
require.Len(t, idTokenValues, 0)
|
||||
}
|
||||
if test.cookie {
|
||||
require.Equal(t, []string{cookieValue}, cookieValues)
|
||||
} else {
|
||||
require.Len(t, cookieValues, 0)
|
||||
}
|
||||
})
|
||||
|
||||
req := backend.CallResourceRequest{
|
||||
Headers: map[string][]string{},
|
||||
Method: "GET",
|
||||
URL: "labels?",
|
||||
}
|
||||
|
||||
if test.auth {
|
||||
req.Headers[authName] = []string{authValue}
|
||||
req.Headers[idTokenName] = []string{idTokenValue}
|
||||
}
|
||||
if test.cookie {
|
||||
req.Headers[cookieName] = []string{cookieValue}
|
||||
}
|
||||
|
||||
sender := &mockedCallResourceResponseSenderForOauth{}
|
||||
|
||||
err := callResource(context.Background(), &req, sender, &dsInfo, log.New("testlog"))
|
||||
// we do a basic check that the result is OK
|
||||
require.NoError(t, err)
|
||||
sent := sender.Response
|
||||
require.NotNil(t, sent)
|
||||
require.Equal(t, http.StatusOK, sent.Status)
|
||||
require.Equal(t, response, sent.Body)
|
||||
|
||||
// we need to be sure the client-callback was triggered
|
||||
require.True(t, clientUsed)
|
||||
})
|
||||
}
|
||||
}
|
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -96,22 +95,6 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
|
||||
}
|
||||
}
|
||||
|
||||
// in the CallResource API, request-headers are in a map where the value is an array-of-strings,
|
||||
// so we need a helper function that can extract a single string-value from an array-of-strings.
|
||||
// i only deal with two cases:
|
||||
// - zero-length array
|
||||
// - first-item of the array
|
||||
// i do not handle the case where there are multiple items in the array, i do not know
|
||||
// if that can even happen ever, for the headers that we are interested in.
|
||||
func arrayHeaderFirstValue(values []string) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// NOTE: we assume there never is a second item in the http-header-values-array
|
||||
return values[0]
|
||||
}
|
||||
|
||||
func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
|
||||
dsInfo, err := s.getDSInfo(req.PluginContext)
|
||||
if err != nil {
|
||||
@ -120,30 +103,6 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq
|
||||
return callResource(ctx, req, sender, dsInfo, logger.FromContext(ctx))
|
||||
}
|
||||
|
||||
func getHeadersForCallResource(headers map[string][]string) map[string]string {
|
||||
data := make(map[string]string)
|
||||
|
||||
for k, values := range headers {
|
||||
k = textproto.CanonicalMIMEHeaderKey(k)
|
||||
firstValue := arrayHeaderFirstValue(values)
|
||||
|
||||
if firstValue == "" {
|
||||
continue
|
||||
}
|
||||
switch k {
|
||||
case "Authorization":
|
||||
data["Authorization"] = firstValue
|
||||
case "X-Id-Token":
|
||||
data["X-ID-Token"] = firstValue
|
||||
case "Cookie":
|
||||
data["Cookie"] = firstValue
|
||||
case "Accept-Encoding":
|
||||
data["Accept-Encoding"] = firstValue
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func callResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, dsInfo *datasourceInfo, plog log.Logger) error {
|
||||
url := req.URL
|
||||
|
||||
@ -158,7 +117,7 @@ func callResource(ctx context.Context, req *backend.CallResourceRequest, sender
|
||||
}
|
||||
lokiURL := fmt.Sprintf("/loki/api/v1/%s", url)
|
||||
|
||||
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getHeadersForCallResource(req.Headers))
|
||||
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog)
|
||||
encodedBytes, err := api.RawQuery(ctx, lokiURL)
|
||||
|
||||
if err != nil {
|
||||
@ -191,7 +150,7 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
|
||||
func queryData(ctx context.Context, req *backend.QueryDataRequest, dsInfo *datasourceInfo, tracer tracing.Tracer) (*backend.QueryDataResponse, error) {
|
||||
result := backend.NewQueryDataResponse()
|
||||
|
||||
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx), req.Headers)
|
||||
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx))
|
||||
|
||||
queries, err := parseQuery(req)
|
||||
if err != nil {
|
||||
|
@ -1,74 +0,0 @@
|
||||
package loki
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetHeadersForCallResource(t *testing.T) {
|
||||
const idTokn1 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
const idTokn2 = "eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU"
|
||||
const authTokn1 = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
const authTokn2 = "Bearer eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU"
|
||||
|
||||
testCases := map[string]struct {
|
||||
headers map[string][]string
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
"Headers with empty value": {
|
||||
headers: map[string][]string{
|
||||
"X-Grafana-Org-Id": {"1"},
|
||||
"Cookie": {""},
|
||||
"X-Id-Token": {""},
|
||||
"Accept-Encoding": {""},
|
||||
"Authorization": {""},
|
||||
},
|
||||
expectedHeaders: map[string]string{},
|
||||
},
|
||||
"Headers with multiple values": {
|
||||
headers: map[string][]string{
|
||||
"Authorization": {authTokn1, authTokn2},
|
||||
"Cookie": {"a=1"},
|
||||
"X-Grafana-Org-Id": {"1"},
|
||||
"Accept-Encoding": {"gzip", "compress"},
|
||||
"X-Id-Token": {idTokn1, idTokn2},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"Authorization": authTokn1,
|
||||
"Cookie": "a=1",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-ID-Token": idTokn1,
|
||||
},
|
||||
},
|
||||
"Headers with single value": {
|
||||
headers: map[string][]string{
|
||||
"Authorization": {authTokn1},
|
||||
"X-Grafana-Org-Id": {"1"},
|
||||
"Cookie": {"a=1"},
|
||||
"Accept-Encoding": {"gzip"},
|
||||
"X-Id-Token": {idTokn1},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"Authorization": authTokn1,
|
||||
"Cookie": "a=1",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-ID-Token": idTokn1,
|
||||
},
|
||||
},
|
||||
"Non Canonical 'X-Id-Token' header key": {
|
||||
headers: map[string][]string{
|
||||
"X-ID-TOKEN": {idTokn1},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-ID-Token": idTokn1,
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, test := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
headers := getHeadersForCallResource(test.headers)
|
||||
assert.Equal(t, test.expectedHeaders, headers)
|
||||
})
|
||||
}
|
||||
}
|
@ -32,7 +32,7 @@ func NewClient(d doer, method, baseUrl string) *Client {
|
||||
return &Client{doer: d, method: method, baseUrl: baseUrl}
|
||||
}
|
||||
|
||||
func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
|
||||
func (c *Client) QueryRange(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
tr := q.TimeRange()
|
||||
qv := map[string]string{
|
||||
"query": q.Expr,
|
||||
@ -41,7 +41,7 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H
|
||||
"step": strconv.FormatFloat(tr.Step.Seconds(), 'f', -1, 64),
|
||||
}
|
||||
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv, headers)
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -49,14 +49,14 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H
|
||||
return c.doer.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
|
||||
func (c *Client) QueryInstant(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
qv := map[string]string{"query": q.Expr}
|
||||
tr := q.TimeRange()
|
||||
if !tr.End.IsZero() {
|
||||
qv["time"] = formatTime(tr.End)
|
||||
}
|
||||
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query", qv, headers)
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query", qv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -64,7 +64,7 @@ func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http
|
||||
return c.doer.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
|
||||
func (c *Client) QueryExemplars(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
tr := q.TimeRange()
|
||||
qv := map[string]string{
|
||||
"query": q.Expr,
|
||||
@ -72,7 +72,7 @@ func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers ht
|
||||
"end": formatTime(tr.End),
|
||||
}
|
||||
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv, headers)
|
||||
req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -95,7 +95,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq
|
||||
|
||||
// We use method from the request, as for resources front end may do a fallback to GET if POST does not work
|
||||
// nad we want to respect that.
|
||||
httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body), req.Headers)
|
||||
httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -103,7 +103,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq
|
||||
return c.doer.Do(httpRequest)
|
||||
}
|
||||
|
||||
func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string, headers http.Header) (*http.Request, error) {
|
||||
func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string) (*http.Request, error) {
|
||||
if strings.ToUpper(c.method) == http.MethodPost {
|
||||
u, err := c.createUrl(endpoint, nil)
|
||||
if err != nil {
|
||||
@ -115,7 +115,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map
|
||||
v.Set(key, val)
|
||||
}
|
||||
|
||||
return createRequest(ctx, c.method, u, strings.NewReader(v.Encode()), headers)
|
||||
return createRequest(ctx, c.method, u, strings.NewReader(v.Encode()))
|
||||
}
|
||||
|
||||
u, err := c.createUrl(endpoint, qv)
|
||||
@ -123,7 +123,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return createRequest(ctx, c.method, u, http.NoBody, headers)
|
||||
return createRequest(ctx, c.method, u, http.NoBody)
|
||||
}
|
||||
|
||||
func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, error) {
|
||||
@ -148,15 +148,12 @@ func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, err
|
||||
return finalUrl, nil
|
||||
}
|
||||
|
||||
func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader, header http.Header) (*http.Request, error) {
|
||||
func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader) (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, method, u.String(), bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// request.Header is created empty from NewRequestWithContext so we can just replace it
|
||||
if header != nil {
|
||||
request.Header = header
|
||||
}
|
||||
|
||||
if strings.ToUpper(method) == http.MethodPost {
|
||||
// This may not be true but right now we don't have more information here and seems like we send just this type
|
||||
// of encoding right now if it is a POST
|
||||
|
@ -37,7 +37,6 @@ func TestClient(t *testing.T) {
|
||||
Path: "/api/v1/series",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/v1/series",
|
||||
Headers: nil,
|
||||
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
|
||||
}
|
||||
res, err := client.QueryResource(context.Background(), req)
|
||||
@ -63,7 +62,6 @@ func TestClient(t *testing.T) {
|
||||
Path: "/api/v1/series",
|
||||
Method: http.MethodGet,
|
||||
URL: "api/v1/series?match%5B%5D=ALERTS&start=1655272558&end=1655294158",
|
||||
Headers: nil,
|
||||
}
|
||||
res, err := client.QueryResource(context.Background(), req)
|
||||
defer func() {
|
||||
@ -95,7 +93,7 @@ func TestClient(t *testing.T) {
|
||||
RangeQuery: true,
|
||||
Step: 1 * time.Second,
|
||||
}
|
||||
res, err := client.QueryRange(context.Background(), req, nil)
|
||||
res, err := client.QueryRange(context.Background(), req)
|
||||
defer func() {
|
||||
if res != nil && res.Body != nil {
|
||||
if err := res.Body.Close(); err != nil {
|
||||
@ -122,7 +120,7 @@ func TestClient(t *testing.T) {
|
||||
RangeQuery: true,
|
||||
Step: 1 * time.Second,
|
||||
}
|
||||
res, err := client.QueryRange(context.Background(), req, nil)
|
||||
res, err := client.QueryRange(context.Background(), req)
|
||||
defer func() {
|
||||
if res != nil && res.Body != nil {
|
||||
if err := res.Body.Close(); err != nil {
|
||||
|
@ -85,11 +85,7 @@ func TestService(t *testing.T) {
|
||||
Path: "/api/v1/series",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/v1/series",
|
||||
// This header should be passed on to the resource request
|
||||
Headers: map[string][]string{
|
||||
"foo": {"bar"},
|
||||
},
|
||||
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
|
||||
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
|
||||
}
|
||||
|
||||
sender := &fakeSender{}
|
||||
@ -100,7 +96,6 @@ func TestService(t *testing.T) {
|
||||
http.Header{
|
||||
"Content-Type": {"application/x-www-form-urlencoded"},
|
||||
"Idempotency-Key": []string(nil),
|
||||
"foo": {"bar"},
|
||||
},
|
||||
httpProvider.Roundtripper.Req.Header)
|
||||
require.Equal(t, http.MethodPost, httpProvider.Roundtripper.Req.Method)
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
|
||||
"github.com/grafana/grafana/pkg/tsdb/intervalv2"
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/models"
|
||||
@ -79,7 +80,7 @@ func New(
|
||||
}
|
||||
|
||||
func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
|
||||
fromAlert := req.Headers["FromAlert"] == "true"
|
||||
fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName] == "true"
|
||||
result := backend.QueryDataResponse{
|
||||
Responses: backend.Responses{},
|
||||
}
|
||||
@ -89,7 +90,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest)
|
||||
if err != nil {
|
||||
return &result, err
|
||||
}
|
||||
r, err := s.fetch(ctx, s.client, query, req.Headers)
|
||||
r, err := s.fetch(ctx, s.client, query)
|
||||
if err != nil {
|
||||
return &result, err
|
||||
}
|
||||
@ -103,7 +104,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
|
||||
func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query) (*backend.DataResponse, error) {
|
||||
traceCtx, end := s.trace(ctx, q)
|
||||
defer end()
|
||||
|
||||
@ -116,7 +117,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
|
||||
}
|
||||
|
||||
if q.InstantQuery {
|
||||
res, err := s.instantQuery(traceCtx, client, q, headers)
|
||||
res, err := s.instantQuery(traceCtx, client, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -125,7 +126,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
|
||||
}
|
||||
|
||||
if q.RangeQuery {
|
||||
res, err := s.rangeQuery(traceCtx, client, q, headers)
|
||||
res, err := s.rangeQuery(traceCtx, client, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -140,7 +141,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
|
||||
}
|
||||
|
||||
if q.ExemplarQuery {
|
||||
res, err := s.exemplarQuery(traceCtx, client, q, headers)
|
||||
res, err := s.exemplarQuery(traceCtx, client, q)
|
||||
if err != nil {
|
||||
// If exemplar query returns error, we want to only log it and
|
||||
// continue with other results processing
|
||||
@ -154,24 +155,24 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryRange(ctx, q, sdkHeaderToHttpHeader(headers))
|
||||
func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryRange(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.parseResponse(ctx, q, res)
|
||||
}
|
||||
|
||||
func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryInstant(ctx, q, sdkHeaderToHttpHeader(headers))
|
||||
func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryInstant(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.parseResponse(ctx, q, res)
|
||||
}
|
||||
|
||||
func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryExemplars(ctx, q, sdkHeaderToHttpHeader(headers))
|
||||
func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
|
||||
res, err := c.QueryExemplars(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -185,11 +186,3 @@ func (s *QueryData) trace(ctx context.Context, q *models.Query) (context.Context
|
||||
{Key: "stop_unixnano", Value: q.End, Kv: attribute.Key("stop_unixnano").Int64(q.End.UnixNano())},
|
||||
})
|
||||
}
|
||||
|
||||
func sdkHeaderToHttpHeader(headers map[string]string) http.Header {
|
||||
httpHeader := make(http.Header, len(headers))
|
||||
for key, val := range headers {
|
||||
httpHeader.Set(key, val)
|
||||
}
|
||||
return httpHeader
|
||||
}
|
||||
|
@ -16,7 +16,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
|
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
|
||||
p "github.com/prometheus/common/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
@ -343,27 +342,6 @@ func TestPrometheus_parseTimeSeriesResponse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrometheusCanonicalHeaders(t *testing.T) {
|
||||
// Ensure headers are always canonicalized for all outgoing requests
|
||||
b, err := json.Marshal(models.QueryModel{})
|
||||
require.NoError(t, err)
|
||||
query := backend.DataQuery{JSON: b}
|
||||
tctx, err := setup(true)
|
||||
require.NoError(t, err)
|
||||
const idToken = "abc"
|
||||
_, err = executeWithHeaders(tctx, query, queryResult{}, map[string]string{
|
||||
"X-Id-Token": idToken,
|
||||
"X-ID-Token": idToken,
|
||||
"X-Other": "thing",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tctx.httpProvider.req.Header)
|
||||
// Check the request that hit the fake prometheus server to ensure headers are valid
|
||||
assert.Equal(t, []string{idToken}, tctx.httpProvider.req.Header["X-Id-Token"])
|
||||
assert.Empty(t, tctx.httpProvider.req.Header["X-ID-Token"]) //nolint:staticcheck
|
||||
assert.Equal(t, []string{"thing"}, tctx.httpProvider.req.Header["X-Other"])
|
||||
}
|
||||
|
||||
type queryResult struct {
|
||||
Type p.ValueType `json:"resultType"`
|
||||
Result interface{} `json:"result"`
|
||||
|
Loading…
Reference in New Issue
Block a user