Plugins: Automatically forward plugin request HTTP headers in outgoing HTTP requests (#60417)

Automatically forward core plugin request HTTP headers in outgoing HTTP requests. 
Core datasource plugin authors don't have to specifically handle forwarding of HTTP 
headers, e.g. do not have to "hardcode" the header-names in the datasource plugin, 
if not having custom needs.

Fixes #57065
This commit is contained in:
Marcus Efraimsson 2022-12-21 13:25:58 +01:00 committed by GitHub
parent aaab477594
commit c35c689a96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 816 additions and 1194 deletions

View File

@ -1,27 +0,0 @@
package httpclientprovider
import (
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
const DeleteHeadersMiddlewareName = "delete-headers"
// DeleteHeadersMiddleware middleware that delete headers on the outgoing
// request if header names provided.
func DeleteHeadersMiddleware(headerNames ...string) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(DeleteHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if len(headerNames) == 0 {
return next
}
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
for _, k := range headerNames {
req.Header.Del(k)
}
return next.RoundTrip(req)
})
})
}

View File

@ -1,66 +0,0 @@
package httpclientprovider
import (
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/require"
)
func TestDeleteHeadersMiddleware(t *testing.T) {
t.Run("Without headerNames should return next http.RoundTripper", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("finalrt")
var headerNames []string
mw := DeleteHeadersMiddleware(headerNames...)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
})
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("final")
headerNames := []string{"X-Header-B", "X-Header-C"}
mw := DeleteHeadersMiddleware(headerNames...)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
req.Header.Set("X-Header-A", "a")
req.Header.Set("X-Header-B", "b")
req.Header.Set("X-Header-C", "c")
req.Header.Set("X-Header-D", "d")
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
require.Equal(t, "a", req.Header.Get("X-Header-A"))
require.Empty(t, req.Header.Get("X-Header-B"))
require.Empty(t, req.Header.Get("X-Header-C"))
require.Equal(t, "d", req.Header.Get("X-Header-D"))
})
}

View File

@ -1,31 +0,0 @@
package httpclientprovider
import (
"net/http"
"net/textproto"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
const SetHeadersMiddlewareName = "set-headers"
// SetHeadersMiddleware middleware that sets headers on the outgoing
// request if headers provided.
// If the request already contains any of the headers provided, they
// will be overwritten.
func SetHeadersMiddleware(headers http.Header) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(SetHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if len(headers) == 0 {
return next
}
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
for k, v := range headers {
canonicalKey := textproto.CanonicalMIMEHeaderKey(k)
req.Header[canonicalKey] = v
}
return next.RoundTrip(req)
})
})
}

View File

@ -1,66 +0,0 @@
package httpclientprovider
import (
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/require"
)
func TestSetHeadersMiddleware(t *testing.T) {
t.Run("Without headers set should return next http.RoundTripper", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("finalrt")
var headers http.Header
mw := SetHeadersMiddleware(headers)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
})
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("final")
headers := http.Header{
"X-Header-A": []string{"value a"},
"X-Header-B": []string{"value b"},
"X-HEader-C": []string{"value c"},
}
mw := SetHeadersMiddleware(headers)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
req.Header.Set("X-Header-B", "d")
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
require.Equal(t, "value a", req.Header.Get("X-Header-A"))
require.Equal(t, "value b", req.Header.Get("X-Header-B"))
require.Equal(t, "value c", req.Header.Get("X-Header-C"))
})
}

View File

@ -255,6 +255,7 @@ var hopHeaders = []string{
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
"User-Agent",
}
// removeHopByHopHeaders removes hop-by-hop headers. Especially

View File

@ -18,6 +18,7 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/alerting"
"github.com/grafana/grafana/pkg/services/datasources"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
)
func init() {
@ -276,8 +277,8 @@ func (c *QueryCondition) getRequestForAlertRule(datasource *datasources.DataSour
},
},
Headers: map[string]string{
"FromAlert": "true",
"X-Cache-Skip": "true",
ngalertmodels.FromAlertHeaderName: "true",
ngalertmodels.CacheSkipHeaderName: "true",
},
Debug: debug,
}

View File

@ -223,9 +223,9 @@ func buildDatasourceHeaders(ctx EvaluationContext) map[string]string {
// Note: The spelling of this headers is intentionally degenerate from the others for compatibility reasons.
// When sent over a network, the key of this header is canonicalized to "Fromalert".
// However, some datasources still compare against the string "FromAlert".
"FromAlert": "true",
models.FromAlertHeaderName: "true",
"X-Cache-Skip": "true",
models.CacheSkipHeaderName: "true",
}
key, ok := models.RuleKeyFromContext(ctx.Ctx)

View File

@ -0,0 +1,21 @@
package models
const (
// FromAlertHeaderName name of header added to datasource query requests
// to denote request is originating from Grafana Alerting.
//
// Data sources might check this in query method as sometimes alerting
// needs special considerations.
// Several existing systems also compare against the value of this header.
// Altering this constitutes a breaking change.
//
// Note: The spelling of this headers is intentionally degenerate from the
// others for compatibility reasons. When sent over a network, the key of
// this header is canonicalized to "Fromalert".
// However, some datasources still compare against the string "FromAlert".
FromAlertHeaderName = "FromAlert"
// CacheSkipHeaderName name of header added to datasource query requests
// to denote request should not be cached.
CacheSkipHeaderName = "X-Cache-Skip"
)

View File

@ -4,8 +4,6 @@ import (
"context"
"github.com/grafana/grafana-plugin-sdk-go/backend"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
)
@ -25,19 +23,19 @@ type ClearAuthHeadersMiddleware struct {
next plugins.Client
}
func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, pCtx backend.PluginContext, req interface{}) context.Context {
func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, h backend.ForwardHTTPHeaders) {
reqCtx := contexthandler.FromContext(ctx)
// if no HTTP request context skip middleware
if req == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
return ctx
if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
return
}
list := contexthandler.AuthHTTPHeaderListFromContext(ctx)
if list != nil {
ctx = sdkhttpclient.WithContextualMiddleware(ctx, httpclientprovider.DeleteHeadersMiddleware(list.Items...))
for _, k := range list.Items {
h.DeleteHTTPHeader(k)
}
}
return ctx
}
func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
@ -45,7 +43,7 @@ func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend
return m.next.QueryData(ctx, req)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
m.clearHeaders(ctx, req)
return m.next.QueryData(ctx, req)
}
@ -55,7 +53,7 @@ func (m *ClearAuthHeadersMiddleware) CallResource(ctx context.Context, req *back
return m.next.CallResource(ctx, req, sender)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
m.clearHeaders(ctx, req)
return m.next.CallResource(ctx, req, sender)
}
@ -65,7 +63,7 @@ func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backe
return m.next.CheckHealth(ctx, req)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
m.clearHeaders(ctx, req)
return m.next.CheckHealth(ctx, req)
}

View File

@ -1,14 +1,10 @@
package clientmiddleware
import (
"bytes"
"io"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/user"
@ -42,9 +38,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
@ -55,9 +48,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
@ -68,9 +58,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
@ -92,9 +79,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
@ -105,9 +89,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
@ -118,9 +99,6 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
})
@ -155,18 +133,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
})
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
@ -177,18 +144,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader])
})
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
@ -199,18 +155,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
})
})
@ -220,12 +165,12 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
)
const customHeader = "X-Custom"
const customHeader = "x-Custom"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)
const otherHeader = "X-Other"
const otherHeader = "x-Other"
req.Header.Set(otherHeader, "test")
pluginCtx := backend.PluginContext{
@ -240,18 +185,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
})
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
@ -262,18 +196,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader])
})
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
@ -284,27 +207,8 @@ func TestClearAuthHeadersMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
})
})
})
}
var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Request: req,
Body: io.NopCloser(bytes.NewBufferString("")),
}, nil
})

View File

@ -4,9 +4,7 @@ import (
"context"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
@ -16,8 +14,8 @@ import (
const cookieHeaderName = "Cookie"
// NewCookiesMiddleware creates a new plugins.ClientMiddleware that will
// forward incoming HTTP request Cookies to outgoing plugins.Client and
// HTTP requests if the datasource has enabled forwarding of cookies (keepCookies).
// forward incoming HTTP request Cookies to outgoing plugins.Client requests
// if the datasource has enabled forwarding of cookies (keepCookies).
func NewCookiesMiddleware(skipCookiesNames []string) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &CookiesMiddleware{
@ -32,17 +30,17 @@ type CookiesMiddleware struct {
skipCookiesNames []string
}
func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) error {
reqCtx := contexthandler.FromContext(ctx)
// if request not for a datasource or no HTTP request context skip middleware
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
return ctx, nil
return nil
}
settings := pCtx.DataSourceInstanceSettings
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
if err != nil {
return ctx, err
return err
}
ds := &datasources.DataSource{
@ -54,20 +52,29 @@ func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.Plugi
proxyutil.ClearCookieHeader(reqCtx.Req, ds.AllowedCookies(), m.skipCookiesNames)
if cookieStr := reqCtx.Req.Header.Get(cookieHeaderName); cookieStr != "" {
switch t := req.(type) {
case *backend.QueryDataRequest:
cookieStr := reqCtx.Req.Header.Get(cookieHeaderName)
switch t := req.(type) {
case *backend.QueryDataRequest:
if cookieStr == "" {
delete(t.Headers, cookieHeaderName)
} else {
t.Headers[cookieHeaderName] = cookieStr
case *backend.CheckHealthRequest:
}
case *backend.CheckHealthRequest:
if cookieStr == "" {
delete(t.Headers, cookieHeaderName)
} else {
t.Headers[cookieHeaderName] = cookieStr
case *backend.CallResourceRequest:
}
case *backend.CallResourceRequest:
if cookieStr == "" {
delete(t.Headers, cookieHeaderName)
} else {
t.Headers[cookieHeaderName] = []string{cookieStr}
}
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(reqCtx.Req.Cookies(), ds.AllowedCookies(), m.skipCookiesNames))
return ctx, nil
return nil
}
func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
@ -75,12 +82,12 @@ func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDat
return m.next.QueryData(ctx, req)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.QueryData(newCtx, req)
return m.next.QueryData(ctx, req)
}
func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
@ -88,12 +95,12 @@ func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallR
return m.next.CallResource(ctx, req, sender)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return err
}
return m.next.CallResource(newCtx, req, sender)
return m.next.CallResource(ctx, req, sender)
}
func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
@ -101,12 +108,12 @@ func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckH
return m.next.CheckHealth(ctx, req)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.CheckHealth(newCtx, req)
return m.next.CheckHealth(ctx, req)
}
func (m *CookiesMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {

View File

@ -6,8 +6,6 @@ import (
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/require"
@ -54,39 +52,19 @@ func TestCookiesMiddleware(t *testing.T) {
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should not forward cookies when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
pReq := &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
}
pReq.Headers[backend.CookiesHeaderName] = []string{req.Header.Get(backend.CookiesHeaderName)}
err = cdt.Decorator.CallResource(req.Context(), pReq, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should not forward cookies when calling CheckHealth", func(t *testing.T) {
@ -98,17 +76,6 @@ func TestCookiesMiddleware(t *testing.T) {
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
})
@ -157,18 +124,6 @@ func TestCookiesMiddleware(t *testing.T) {
require.Len(t, cdt.QueryDataReq.Headers, 2)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
require.EqualValues(t, "cookie2=", cdt.QueryDataReq.Headers[cookieHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
t.Run("Should forward cookies when calling CallResource", func(t *testing.T) {
@ -182,18 +137,6 @@ func TestCookiesMiddleware(t *testing.T) {
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
require.Len(t, cdt.CallResourceReq.Headers[cookieHeaderName], 1)
require.EqualValues(t, "cookie2=", cdt.CallResourceReq.Headers[cookieHeaderName][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
t.Run("Should forward cookies when calling CheckHealth", func(t *testing.T) {
@ -206,18 +149,6 @@ func TestCookiesMiddleware(t *testing.T) {
require.Len(t, cdt.CheckHealthReq.Headers, 2)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
require.EqualValues(t, "cookie2=", cdt.CheckHealthReq.Headers[cookieHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
})
}

View File

@ -0,0 +1,108 @@
package clientmiddleware
import (
"context"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/plugins"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
)
const forwardPluginRequestHTTPHeaders = "forward-plugin-request-http-headers"
// NewHTTPClientMiddleware creates a new plugins.ClientMiddleware
// that will forward plugin request headers as outgoing HTTP headers.
func NewHTTPClientMiddleware() plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &HTTPClientMiddleware{
next: next,
}
})
}
type HTTPClientMiddleware struct {
next plugins.Client
}
func (m *HTTPClientMiddleware) applyHeaders(ctx context.Context, pReq interface{}) context.Context {
if pReq == nil {
return ctx
}
mw := httpclient.NamedMiddlewareFunc(forwardPluginRequestHTTPHeaders, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
switch t := pReq.(type) {
case *backend.QueryDataRequest:
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
req.Header.Set(ngalertmodels.FromAlertHeaderName, val)
}
case *backend.CallResourceRequest:
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
req.Header.Set(ngalertmodels.FromAlertHeaderName, val[0])
}
case *backend.CheckHealthRequest:
if val, exists := t.Headers[ngalertmodels.FromAlertHeaderName]; exists {
req.Header.Set(ngalertmodels.FromAlertHeaderName, val)
}
}
if h, ok := pReq.(backend.ForwardHTTPHeaders); ok {
for k, v := range h.GetHTTPHeaders() {
req.Header[k] = v
}
}
return next.RoundTrip(req)
})
})
return httpclient.WithContextualMiddleware(ctx, mw)
}
func (m *HTTPClientMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return m.next.QueryData(ctx, req)
}
ctx = m.applyHeaders(ctx, req)
return m.next.QueryData(ctx, req)
}
func (m *HTTPClientMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return m.next.CallResource(ctx, req, sender)
}
ctx = m.applyHeaders(ctx, req)
return m.next.CallResource(ctx, req, sender)
}
func (m *HTTPClientMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return m.next.CheckHealth(ctx, req)
}
ctx = m.applyHeaders(ctx, req)
return m.next.CheckHealth(ctx, req)
}
func (m *HTTPClientMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
return m.next.CollectMetrics(ctx, req)
}
func (m *HTTPClientMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
return m.next.SubscribeStream(ctx, req)
}
func (m *HTTPClientMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
return m.next.PublishStream(ctx, req)
}
func (m *HTTPClientMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
return m.next.RunStream(ctx, req, sender)
}

View File

@ -0,0 +1,289 @@
package clientmiddleware
import (
"bytes"
"io"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/proxyutil"
"github.com/stretchr/testify/require"
)
func TestHTTPClientMiddleware(t *testing.T) {
const otherHeader = "test"
t.Run("When no http headers in plugin request", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
t.Run("And requests are for a datasource", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
}
t.Run("Should not forward headers when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "val"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
t.Run("Should not forward headers when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"val"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "val"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
})
t.Run("And requests are for an app", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
)
pluginCtx := backend.PluginContext{
AppInstanceSettings: &backend.AppInstanceSettings{},
}
t.Run("Should not forward headers when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 0)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
t.Run("Should not forward headers when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 0)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
t.Run("Should not forward headers when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 0)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 0)
})
})
})
t.Run("When HTTP headers in plugin request", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
headers := map[string]string{
ngalertmodels.FromAlertHeaderName: "true",
backend.OAuthIdentityTokenHeaderName: "bearer token",
backend.OAuthIdentityIDTokenHeaderName: "id-token",
"http_" + proxyutil.UserHeaderName: "uname",
backend.CookiesHeaderName: "cookie1=; cookie2=; cookie3=",
otherHeader: "val",
}
crHeaders := map[string][]string{}
for k, v := range headers {
crHeaders[k] = []string{v}
}
t.Run("And requests are for a datasource", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewHTTPClientMiddleware()),
)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
}
t.Run("Should forward headers when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: headers,
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 6)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 5)
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
require.Len(t, reqClone.Cookies(), 3)
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
})
t.Run("Should forward headers when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: crHeaders,
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 6)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 5)
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
require.Len(t, reqClone.Cookies(), 3)
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
})
t.Run("Should forward headers when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: headers,
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 6)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, forwardPluginRequestHTTPHeaders, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 5)
require.Equal(t, "true", reqClone.Header.Get(ngalertmodels.FromAlertHeaderName))
require.Equal(t, "bearer token", reqClone.Header.Get(backend.OAuthIdentityTokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
require.Equal(t, "uname", reqClone.Header.Get(proxyutil.UserHeaderName))
require.Len(t, reqClone.Cookies(), 3)
require.Equal(t, "cookie1", reqClone.Cookies()[0].Name)
require.Equal(t, "cookie2", reqClone.Cookies()[1].Name)
require.Equal(t, "cookie3", reqClone.Cookies()[2].Name)
})
})
})
}
var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Request: req,
Body: io.NopCloser(bytes.NewBufferString("")),
}, nil
})

View File

@ -3,12 +3,9 @@ package clientmiddleware
import (
"context"
"fmt"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
@ -16,8 +13,8 @@ import (
)
// NewOAuthTokenMiddleware creates a new plugins.ClientMiddleware that will
// set OAuth token headers on outgoing plugins.Client and HTTP requests if
// the datasource has enabled Forward OAuth Identity (oauthPassThru).
// set OAuth token headers on outgoing plugins.Client requests if the
// datasource has enabled Forward OAuth Identity (oauthPassThru).
func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &OAuthTokenMiddleware{
@ -37,17 +34,17 @@ type OAuthTokenMiddleware struct {
next plugins.Client
}
func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) error {
reqCtx := contexthandler.FromContext(ctx)
// if request not for a datasource or no HTTP request context skip middleware
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
return ctx, nil
return nil
}
settings := pCtx.DataSourceInstanceSettings
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
if err != nil {
return ctx, err
return err
}
ds := &datasources.DataSource{
@ -84,19 +81,10 @@ func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.Plug
t.Headers[idTokenHeaderName] = []string{idTokenHeader}
}
}
httpHeaders := http.Header{}
httpHeaders.Set(tokenHeaderName, authorizationHeader)
if idTokenHeader != "" {
httpHeaders.Set(idTokenHeaderName, idTokenHeader)
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.SetHeadersMiddleware(httpHeaders))
}
}
return ctx, nil
return nil
}
func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
@ -104,12 +92,12 @@ func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.Query
return m.next.QueryData(ctx, req)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.QueryData(newCtx, req)
return m.next.QueryData(ctx, req)
}
func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
@ -117,12 +105,12 @@ func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.Ca
return m.next.CallResource(ctx, req, sender)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return err
}
return m.next.CallResource(newCtx, req, sender)
return m.next.CallResource(ctx, req, sender)
}
func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
@ -130,12 +118,12 @@ func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.Che
return m.next.CheckHealth(ctx, req)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.CheckHealth(newCtx, req)
return m.next.CheckHealth(ctx, req)
}
func (m *OAuthTokenMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {

View File

@ -6,8 +6,6 @@ import (
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
@ -49,9 +47,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not forward OAuth Identity when calling CallResource", func(t *testing.T) {
@ -63,9 +58,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
@ -77,9 +69,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
@ -125,19 +114,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName])
require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) {
@ -153,19 +129,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0])
require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1)
require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
@ -179,19 +142,6 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName])
require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
})
}

View File

@ -2,19 +2,15 @@ package clientmiddleware
import (
"context"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/util/proxyutil"
)
// NewUserHeaderMiddleware creates a new plugins.ClientMiddleware that will
// populate the X-Grafana-User header on outgoing plugins.Client and HTTP
// requests.
// populate the X-Grafana-User header on outgoing plugins.Client requests.
func NewUserHeaderMiddleware() plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &UserHeaderMiddleware{
@ -27,33 +23,17 @@ type UserHeaderMiddleware struct {
next plugins.Client
}
func (m *UserHeaderMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, h backend.ForwardHTTPHeaders) context.Context {
func (m *UserHeaderMiddleware) applyUserHeader(ctx context.Context, h backend.ForwardHTTPHeaders) {
reqCtx := contexthandler.FromContext(ctx)
// if no HTTP request context skip middleware
if h == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
return ctx
return
}
h.DeleteHTTPHeader(proxyutil.UserHeaderName)
if !reqCtx.IsAnonymous {
h.SetHTTPHeader(proxyutil.UserHeaderName, reqCtx.Login)
}
middlewares := []sdkhttpclient.Middleware{}
if !reqCtx.IsAnonymous {
httpHeaders := http.Header{
proxyutil.UserHeaderName: []string{reqCtx.Login},
}
middlewares = append(middlewares, httpclientprovider.SetHeadersMiddleware(httpHeaders))
} else {
middlewares = append(middlewares, httpclientprovider.DeleteHeadersMiddleware(proxyutil.UserHeaderName))
}
ctx = sdkhttpclient.WithContextualMiddleware(ctx, middlewares...)
return ctx
}
func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
@ -61,7 +41,7 @@ func (m *UserHeaderMiddleware) QueryData(ctx context.Context, req *backend.Query
return m.next.QueryData(ctx, req)
}
ctx = m.applyToken(ctx, req.PluginContext, req)
m.applyUserHeader(ctx, req)
return m.next.QueryData(ctx, req)
}
@ -71,7 +51,7 @@ func (m *UserHeaderMiddleware) CallResource(ctx context.Context, req *backend.Ca
return m.next.CallResource(ctx, req, sender)
}
ctx = m.applyToken(ctx, req.PluginContext, req)
m.applyUserHeader(ctx, req)
return m.next.CallResource(ctx, req, sender)
}
@ -81,7 +61,7 @@ func (m *UserHeaderMiddleware) CheckHealth(ctx context.Context, req *backend.Che
return m.next.CheckHealth(ctx, req)
}
ctx = m.applyToken(ctx, req.PluginContext, req)
m.applyUserHeader(ctx, req)
return m.next.CheckHealth(ctx, req)
}

View File

@ -5,8 +5,6 @@ import (
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/proxyutil"
@ -39,10 +37,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Empty(t, cdt.QueryDataReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should not forward user header when calling CallResource", func(t *testing.T) {
@ -53,10 +47,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Empty(t, cdt.CallResourceReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) {
@ -67,10 +57,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Empty(t, cdt.CheckHealthReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
})
@ -95,10 +81,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Empty(t, cdt.QueryDataReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should not forward user header when calling CallResource", func(t *testing.T) {
@ -109,10 +91,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Empty(t, cdt.CallResourceReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should not forward user header when calling CheckHealth", func(t *testing.T) {
@ -123,10 +101,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Empty(t, cdt.CheckHealthReq.Headers)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
})
})
@ -156,10 +130,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should forward user header when calling CallResource", func(t *testing.T) {
@ -171,10 +141,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) {
@ -186,10 +152,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
})
@ -214,10 +176,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "admin", cdt.QueryDataReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should forward user header when calling CallResource", func(t *testing.T) {
@ -229,10 +187,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "admin", cdt.CallResourceReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
t.Run("Should forward user header when calling CheckHealth", func(t *testing.T) {
@ -244,10 +198,6 @@ func TestUserHeaderMiddleware(t *testing.T) {
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName))
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
})
})
})

View File

@ -81,5 +81,7 @@ func CreateMiddlewares(cfg *setting.Cfg, oAuthTokenService oauthtoken.OAuthToken
middlewares = append(middlewares, clientmiddleware.NewUserHeaderMiddleware())
}
middlewares = append(middlewares, clientmiddleware.NewHTTPClientMiddleware())
return middlewares
}

View File

@ -26,92 +26,226 @@ import (
"golang.org/x/oauth2"
)
const loginCookieName = "grafana_session"
func TestIntegrationBackendPlugins(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
regularQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
return metricRequestWithQueries(t, fmt.Sprintf(`{
"datasource": {
"uid": "%s"
}
}`, tsCtx.uid))
oauthToken := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
RefreshToken: "refresh-token",
Expiry: time.Now().UTC().Add(24 * time.Hour),
}
oauthToken = oauthToken.WithExtra(map[string]interface{}{"id_token": "id-token"})
expressionQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
newTestScenario(t, "Datasource with no custom HTTP settings",
options(
withIncomingRequest(func(req *http.Request) {
req.Header.Set("X-Custom", "custom")
req.AddCookie(&http.Cookie{Name: "cookie1"})
req.AddCookie(&http.Cookie{Name: "cookie2"})
req.AddCookie(&http.Cookie{Name: "cookie3"})
req.AddCookie(&http.Cookie{Name: loginCookieName})
}),
),
func(t *testing.T, tsCtx *testScenarioContext) {
verify := func(h backend.ForwardHTTPHeaders) {
require.NotNil(t, h)
require.Empty(t, h.GetHTTPHeader(backend.CookiesHeaderName))
require.Empty(t, h.GetHTTPHeader("Authorization"))
return metricRequestWithQueries(t, fmt.Sprintf(`{
"refId": "A",
"datasource": {
"uid": "%s",
"type": "%s"
require.NotNil(t, tsCtx.outgoingRequest)
require.NotEmpty(t, tsCtx.outgoingRequest.Header)
}
}`, tsCtx.uid, tsCtx.testPluginID), `{
"refId": "B",
"datasource": {
"type": "__expr__",
"uid": "__expr__",
"name": "Expression"
},
"type": "math",
"expression": "$A - 50"
}`)
}
newTestScenario(t, "When oauth token not available", func(t *testing.T, tsCtx *testScenarioContext) {
tsCtx.testEnv.OAuthTokenService.Token = nil
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
verify(pReq)
})
tsCtx.runCheckHealthTest(t)
tsCtx.runCallResourceTest(t)
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
verify(pReq)
require.Equal(t, "custom", pReq.GetHTTPHeader("X-Custom"))
require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom"))
})
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
verifyQueryData := func(pReq *backend.QueryDataRequest) {
verify(pReq)
}
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
})
})
})
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
})
})
newTestScenario(t, "Datasource with most HTTP settings set except oauthPassThru and oauth token available",
options(
withIncomingRequest(func(req *http.Request) {
req.AddCookie(&http.Cookie{Name: "cookie1"})
req.AddCookie(&http.Cookie{Name: "cookie2"})
req.AddCookie(&http.Cookie{Name: "cookie3"})
req.AddCookie(&http.Cookie{Name: loginCookieName})
}),
withOAuthToken(oauthToken),
withDsBasicAuth("basicAuthUser", "basicAuthPassword"),
withDsCustomHeader(map[string]string{"X-CUSTOM-HEADER": "custom-header-value"}),
withDsCookieForwarding([]string{"cookie1", "cookie3", loginCookieName}),
),
func(t *testing.T, tsCtx *testScenarioContext) {
verify := func(h backend.ForwardHTTPHeaders) {
require.NotNil(t, h)
require.Equal(t, "cookie1=; cookie3=", h.GetHTTPHeader(backend.CookiesHeaderName))
require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName))
require.Empty(t, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName))
newTestScenario(t, "When oauth token available", func(t *testing.T, tsCtx *testScenarioContext) {
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
RefreshToken: "refresh-token",
Expiry: time.Now().UTC().Add(24 * time.Hour),
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tsCtx.testEnv.OAuthTokenService.Token = token
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get(backend.CookiesHeaderName))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
tsCtx.runCheckHealthTest(t)
tsCtx.runCallResourceTest(t)
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
}
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
verify(pReq)
})
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
verify(pReq)
})
verifyQueryData := func(pReq *backend.QueryDataRequest) {
verify(pReq)
}
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
})
})
})
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
newTestScenario(t, "Datasource with oauthPassThru and basic auth configured and oauth token available",
options(
withOAuthToken(oauthToken),
withDsOAuthForwarding(),
withDsBasicAuth("basicAuthUser", "basicAuthPassword"),
),
func(t *testing.T, tsCtx *testScenarioContext) {
verify := func(h backend.ForwardHTTPHeaders) {
require.NotNil(t, h)
expectedAuthHeader := fmt.Sprintf("Bearer %s", oauthToken.AccessToken)
expectedTokenHeader := oauthToken.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, h.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName))
require.Equal(t, expectedTokenHeader, h.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName))
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityTokenHeaderName))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get(backend.OAuthIdentityIDTokenHeaderName))
}
tsCtx.runCheckHealthTest(t, func(pReq *backend.CheckHealthRequest) {
verify(pReq)
})
tsCtx.runCallResourceTest(t, func(pReq *backend.CallResourceRequest) {
verify(pReq)
})
verifyQueryData := func(pReq *backend.QueryDataRequest) {
verify(pReq)
}
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createRegularQuery(t, tsCtx), verifyQueryData)
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, createExpressionQuery(t, tsCtx), verifyQueryData)
})
})
})
})
}
type testScenarioContext struct {
testPluginID string
uid string
grafanaListeningAddr string
testEnv *server.TestEnv
outgoingServer *httptest.Server
outgoingRequest *http.Request
backendTestPlugin *testPlugin
rt http.RoundTripper
testPluginID string
uid string
grafanaListeningAddr string
testEnv *server.TestEnv
outgoingServer *httptest.Server
outgoingRequest *http.Request
backendTestPlugin *testPlugin
rt http.RoundTripper
modifyIncomingRequest func(req *http.Request)
}
func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx *testScenarioContext)) {
type testScenarioInput struct {
ds *datasources.AddDataSourceCommand
token *oauth2.Token
modifyIncomingRequest func(req *http.Request)
}
type testScenarioOption func(*testScenarioInput)
func options(opts ...testScenarioOption) []testScenarioOption {
return opts
}
func withIncomingRequest(cb func(req *http.Request)) testScenarioOption {
return func(in *testScenarioInput) {
in.modifyIncomingRequest = cb
}
}
func withOAuthToken(token *oauth2.Token) testScenarioOption {
return func(in *testScenarioInput) {
in.token = token
}
}
func withDsBasicAuth(username, password string) testScenarioOption {
return func(in *testScenarioInput) {
in.ds.BasicAuth = true
in.ds.BasicAuthUser = username
in.ds.SecureJsonData["basicAuthPassword"] = password
}
}
func withDsCustomHeader(headers map[string]string) testScenarioOption {
return func(in *testScenarioInput) {
index := 1
for k, v := range headers {
in.ds.JsonData.Set(fmt.Sprintf("httpHeaderName%d", index), k)
in.ds.SecureJsonData[fmt.Sprintf("httpHeaderValue%d", index)] = v
index++
}
}
}
func withDsOAuthForwarding() testScenarioOption {
return func(in *testScenarioInput) {
in.ds.JsonData.Set("oauthPassThru", true)
}
}
func withDsCookieForwarding(names []string) testScenarioOption {
return func(in *testScenarioInput) {
in.ds.JsonData.Set("keepCookies", names)
}
}
func newTestScenario(t *testing.T, name string, opts []testScenarioOption, callback func(t *testing.T, ctx *testScenarioContext)) {
tsCtx := testScenarioContext{
testPluginID: "test-plugin",
}
@ -123,6 +257,7 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
tsCtx.grafanaListeningAddr = grafanaListeningAddr
testEnv.SQLStore.Cfg.LoginCookieName = loginCookieName
tsCtx.testEnv = testEnv
ctx := context.Background()
@ -143,29 +278,30 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
err := testEnv.PluginRegistry.Add(ctx, testPlugin)
require.NoError(t, err)
jsonData := simplejson.NewFromAny(map[string]interface{}{
"httpHeaderName1": "X-CUSTOM-HEADER",
"oauthPassThru": true,
"keepCookies": []string{"cookie1", "cookie3", "grafana_session"},
})
secureJSONData := map[string]string{
"basicAuthPassword": "basicAuthPassword",
"httpHeaderValue1": "custom-header-value",
}
jsonData := simplejson.New()
secureJSONData := map[string]string{}
tsCtx.uid = "test-plugin"
err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, &datasources.AddDataSourceCommand{
cmd := &datasources.AddDataSourceCommand{
OrgId: 1,
Access: datasources.DS_ACCESS_PROXY,
Name: "TestPlugin",
Type: tsCtx.testPluginID,
Uid: tsCtx.uid,
Url: tsCtx.outgoingServer.URL,
BasicAuth: true,
BasicAuthUser: "basicAuthUser",
JsonData: jsonData,
SecureJsonData: secureJSONData,
})
}
in := &testScenarioInput{ds: cmd}
for _, opt := range opts {
opt(in)
}
tsCtx.modifyIncomingRequest = in.modifyIncomingRequest
tsCtx.testEnv.OAuthTokenService.Token = in.token
err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, cmd)
require.NoError(t, err)
getDataSourceQuery := &datasources.GetDataSourceQuery{
@ -185,17 +321,42 @@ func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx
})
}
func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest) {
t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.QueryDataRequest
func createRegularQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
return metricRequestWithQueries(t, fmt.Sprintf(`{
"datasource": {
"uid": "%s"
}
}`, tsCtx.uid))
}
func createExpressionQuery(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
return metricRequestWithQueries(t, fmt.Sprintf(`{
"refId": "A",
"datasource": {
"uid": "%s",
"type": "%s"
}
}`, tsCtx.uid, tsCtx.testPluginID), `{
"refId": "B",
"datasource": {
"type": "__expr__",
"uid": "__expr__",
"name": "Expression"
},
"type": "math",
"expression": "$A - 50"
}`)
}
func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest, callback func(req *backend.QueryDataRequest)) {
t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) {
var received *backend.QueryDataRequest
tsCtx.backendTestPlugin.QueryDataHandler = backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
received = &struct {
ctx context.Context
req *backend.QueryDataRequest
}{ctx, req}
received = req
c := http.Client{
Transport: tsCtx.rt,
@ -227,18 +388,11 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR
req, err := http.NewRequest(http.MethodPost, u, buf1)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
if tsCtx.modifyIncomingRequest != nil {
tsCtx.modifyIncomingRequest(req)
}
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
@ -253,51 +407,18 @@ func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricR
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
callback(received)
})
}
func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T, callback func(req *backend.CheckHealthRequest)) {
t.Run("When calling /api/datasources/uid/:uid/health should set expected headers on outgoing CheckHealth and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.CheckHealthRequest
}
var received *backend.CheckHealthRequest
tsCtx.backendTestPlugin.CheckHealthHandler = backend.CheckHealthHandlerFunc(func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
received = &struct {
ctx context.Context
req *backend.CheckHealthRequest
}{ctx, req}
received = req
c := http.Client{
Transport: tsCtx.rt,
@ -328,18 +449,11 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, u, nil)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
if tsCtx.modifyIncomingRequest != nil {
tsCtx.modifyIncomingRequest(req)
}
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
@ -354,51 +468,18 @@ func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
callback(received)
})
}
func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T, callback func(req *backend.CallResourceRequest)) {
t.Run("When calling /api/datasources/uid/:uid/resources should set expected headers on outgoing CallResource and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.CallResourceRequest
}
var received *backend.CallResourceRequest
tsCtx.backendTestPlugin.CallResourceHandler = backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
received = &struct {
ctx context.Context
req *backend.CallResourceRequest
}{ctx, req}
received = req
c := http.Client{
Transport: tsCtx.rt,
@ -450,19 +531,11 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
req.Header.Set("Connection", "X-Some-Conn-Header")
req.Header.Set("X-Some-Conn-Header", "should be deleted")
req.Header.Set("Proxy-Connection", "should be deleted")
req.Header.Set("X-Custom", "custom")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36")
if tsCtx.modifyIncomingRequest != nil {
tsCtx.modifyIncomingRequest(req)
}
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
@ -484,45 +557,18 @@ func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
require.Empty(t, resp.Header.Get("Set-Cookie"))
require.Equal(t, "should not be deleted", resp.Header.Get("X-Custom"))
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"][0])
require.Empty(t, received.req.Headers["Connection"])
require.Empty(t, received.req.Headers["X-Some-Conn-Header"])
require.Empty(t, received.req.Headers["Proxy-Connection"])
require.Equal(t, "custom", received.req.Headers["X-Custom"][0])
require.Empty(t, received.Headers["Connection"])
require.Empty(t, received.Headers["X-Some-Conn-Header"])
require.Empty(t, received.Headers["Proxy-Connection"])
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"][0])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"][0])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Empty(t, tsCtx.outgoingRequest.Header.Get("Connection"))
require.Empty(t, tsCtx.outgoingRequest.Header.Get("X-Some-Conn-Header"))
require.Empty(t, tsCtx.outgoingRequest.Header.Get("Proxy-Connection"))
require.Equal(t, "custom", tsCtx.outgoingRequest.Header.Get("X-Custom"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
require.NotEmpty(t, tsCtx.outgoingRequest.Header.Get("Accept-Encoding"))
require.Equal(t, fmt.Sprintf("Grafana/%s", tsCtx.testEnv.SQLStore.Cfg.BuildVersion), tsCtx.outgoingRequest.Header.Get("User-Agent"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
callback(received)
})
}

View File

@ -23,6 +23,7 @@ import (
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/featuremgmt"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/clients"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/models"
@ -162,7 +163,7 @@ func (e *cloudWatchExecutor) QueryData(ctx context.Context, req *backend.QueryDa
if err != nil {
return nil, err
}
_, fromAlert := req.Headers["FromAlert"]
_, fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName]
isLogAlertQuery := fromAlert && model.QueryMode == logsQueryMode
if isLogAlertQuery {

View File

@ -20,6 +20,7 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/services/featuremgmt"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/mocks"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/models"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch/utils"
@ -212,7 +213,7 @@ func Test_executeLogAlertQuery(t *testing.T) {
executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures())
_, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{
Headers: map[string]string{"FromAlert": "some value"},
Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"},
PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}},
Queries: []backend.DataQuery{
{
@ -238,7 +239,7 @@ func Test_executeLogAlertQuery(t *testing.T) {
executor := newExecutor(im, newTestConfig(), &sess, featuremgmt.WithFeatures())
_, err := executor.QueryData(context.Background(), &backend.QueryDataRequest{
Headers: map[string]string{"FromAlert": "some value"},
Headers: map[string]string{ngalertmodels.FromAlertHeaderName: "some value"},
PluginContext: backend.PluginContext{DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}},
Queries: []backend.DataQuery{
{

View File

@ -18,10 +18,9 @@ import (
)
type LokiAPI struct {
client *http.Client
url string
log log.Logger
headers map[string]string
client *http.Client
url string
log log.Logger
}
type RawLokiResponse struct {
@ -29,17 +28,11 @@ type RawLokiResponse struct {
Encoding string
}
func newLokiAPI(client *http.Client, url string, log log.Logger, headers map[string]string) *LokiAPI {
return &LokiAPI{client: client, url: url, log: log, headers: headers}
func newLokiAPI(client *http.Client, url string, log log.Logger) *LokiAPI {
return &LokiAPI{client: client, url: url, log: log}
}
func addHeaders(req *http.Request, headers map[string]string) {
for name, value := range headers {
req.Header.Set(name, value)
}
}
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, headers map[string]string) (*http.Request, error) {
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*http.Request, error) {
qs := url.Values{}
qs.Set("query", query.Expr)
@ -92,8 +85,6 @@ func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, hea
return nil, err
}
addHeaders(req, headers)
if query.VolumeQuery {
req.Header.Set("X-Query-Tags", "Source=logvolhist")
}
@ -145,7 +136,7 @@ func makeLokiError(body io.ReadCloser) error {
}
func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames, error) {
req, err := makeDataRequest(ctx, api.url, query, api.headers)
req, err := makeDataRequest(ctx, api.url, query)
if err != nil {
return nil, err
}
@ -183,7 +174,7 @@ func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (data.Frames
return res.Frames, nil
}
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string, headers map[string]string) (*http.Request, error) {
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string) (*http.Request, error) {
lokiUrl, err := url.Parse(lokiDsUrl)
if err != nil {
return nil, err
@ -204,13 +195,11 @@ func makeRawRequest(ctx context.Context, lokiDsUrl string, resourcePath string,
return nil, err
}
addHeaders(req, headers)
return req, nil
}
func (api *LokiAPI) RawQuery(ctx context.Context, resourcePath string) (RawLokiResponse, error) {
req, err := makeRawRequest(ctx, api.url, resourcePath, api.headers)
req, err := makeRawRequest(ctx, api.url, resourcePath)
if err != nil {
return RawLokiResponse{}, err
}

View File

@ -64,7 +64,7 @@ func makeMockedAPIWithUrl(url string, statusCode int, contentType string, respon
Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback},
}
return newLokiAPI(&client, url, log.New("test"), nil)
return newLokiAPI(&client, url, log.New("test"))
}
func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType string, responseBytes []byte, requestCallback mockRequestCallback) *LokiAPI {
@ -72,5 +72,5 @@ func makeCompressedMockedAPIWithUrl(url string, statusCode int, contentType stri
Transport: &mockedCompressedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback},
}
return newLokiAPI(&client, url, log.New("test"), nil)
return newLokiAPI(&client, url, log.New("test"))
}

View File

@ -1,201 +0,0 @@
package loki
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
)
type mockedRoundTripperForOauth struct {
requestCallback func(req *http.Request)
body []byte
}
func (mockedRT *mockedRoundTripperForOauth) RoundTrip(req *http.Request) (*http.Response, error) {
mockedRT.requestCallback(req)
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(mockedRT.body)),
}, nil
}
type mockedCallResourceResponseSenderForOauth struct {
Response *backend.CallResourceResponse
}
func (s *mockedCallResourceResponseSenderForOauth) Send(resp *backend.CallResourceResponse) error {
s.Response = resp
return nil
}
func makeMockedDsInfoForOauth(body []byte, requestCallback func(req *http.Request)) datasourceInfo {
client := http.Client{
Transport: &mockedRoundTripperForOauth{requestCallback: requestCallback, body: body},
}
return datasourceInfo{
HTTPClient: &client,
}
}
func TestOauthForwardIdentity(t *testing.T) {
tt := []struct {
name string
auth bool
cookie bool
}{
{name: "when auth headers exist => add auth headers", auth: true, cookie: false},
{name: "when cookie header exists => add cookie header", auth: false, cookie: true},
{name: "when cookie&auth headers exist => add cookie&auth headers", auth: true, cookie: true},
{name: "when no header exists => do not add headers", auth: false, cookie: false},
}
authName := "Authorization"
authValue := "auth"
cookieName := "Cookie"
cookieValue := "a=1"
idTokenName := "X-ID-Token"
idTokenValue := "idtoken"
for _, test := range tt {
t.Run("QueryData: "+test.name, func(t *testing.T) {
response := []byte(`
{
"status": "success",
"data": {
"resultType": "streams",
"result": [
{
"stream": {},
"values": [
["1", "line1"]
]
}
]
}
}
`)
clientUsed := false
dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) {
clientUsed = true
// we need to check for "header does not exist",
// and the only way i can find is to get the values
// as an array
authValues := req.Header.Values(authName)
cookieValues := req.Header.Values(cookieName)
idTokenValues := req.Header.Values(idTokenName)
if test.auth {
require.Equal(t, []string{authValue}, authValues)
require.Equal(t, []string{idTokenValue}, idTokenValues)
} else {
require.Len(t, authValues, 0)
require.Len(t, idTokenValues, 0)
}
if test.cookie {
require.Equal(t, []string{cookieValue}, cookieValues)
} else {
require.Len(t, cookieValues, 0)
}
})
req := backend.QueryDataRequest{
Headers: map[string]string{},
Queries: []backend.DataQuery{
{
RefID: "A",
JSON: []byte("{}"),
},
},
}
if test.auth {
req.Headers[authName] = authValue
req.Headers[idTokenName] = idTokenValue
}
if test.cookie {
req.Headers[cookieName] = cookieValue
}
tracer := tracing.InitializeTracerForTest()
data, err := queryData(context.Background(), &req, &dsInfo, tracer)
// we do a basic check that the result is OK
require.NoError(t, err)
require.Len(t, data.Responses, 1)
res := data.Responses["A"]
require.NoError(t, res.Error)
require.Len(t, res.Frames, 1)
require.Equal(t, "line1", res.Frames[0].Fields[2].At(0))
// we need to be sure the client-callback was triggered
require.True(t, clientUsed)
})
}
for _, test := range tt {
t.Run("CallResource: "+test.name, func(t *testing.T) {
response := []byte("mocked resource response")
clientUsed := false
dsInfo := makeMockedDsInfoForOauth(response, func(req *http.Request) {
clientUsed = true
authValues := req.Header.Values(authName)
cookieValues := req.Header.Values(cookieName)
idTokenValues := req.Header.Values(idTokenName)
// we need to check for "header does not exist",
// and the only way i can find is to get the values
// as an array
if test.auth {
require.Equal(t, []string{authValue}, authValues)
require.Equal(t, []string{idTokenValue}, idTokenValues)
} else {
require.Len(t, authValues, 0)
require.Len(t, idTokenValues, 0)
}
if test.cookie {
require.Equal(t, []string{cookieValue}, cookieValues)
} else {
require.Len(t, cookieValues, 0)
}
})
req := backend.CallResourceRequest{
Headers: map[string][]string{},
Method: "GET",
URL: "labels?",
}
if test.auth {
req.Headers[authName] = []string{authValue}
req.Headers[idTokenName] = []string{idTokenValue}
}
if test.cookie {
req.Headers[cookieName] = []string{cookieValue}
}
sender := &mockedCallResourceResponseSenderForOauth{}
err := callResource(context.Background(), &req, sender, &dsInfo, log.New("testlog"))
// we do a basic check that the result is OK
require.NoError(t, err)
sent := sender.Response
require.NotNil(t, sent)
require.Equal(t, http.StatusOK, sent.Status)
require.Equal(t, response, sent.Body)
// we need to be sure the client-callback was triggered
require.True(t, clientUsed)
})
}
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/textproto"
"regexp"
"strings"
"sync"
@ -96,22 +95,6 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
}
}
// in the CallResource API, request-headers are in a map where the value is an array-of-strings,
// so we need a helper function that can extract a single string-value from an array-of-strings.
// i only deal with two cases:
// - zero-length array
// - first-item of the array
// i do not handle the case where there are multiple items in the array, i do not know
// if that can even happen ever, for the headers that we are interested in.
func arrayHeaderFirstValue(values []string) string {
if len(values) == 0 {
return ""
}
// NOTE: we assume there never is a second item in the http-header-values-array
return values[0]
}
func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
dsInfo, err := s.getDSInfo(req.PluginContext)
if err != nil {
@ -120,30 +103,6 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq
return callResource(ctx, req, sender, dsInfo, logger.FromContext(ctx))
}
func getHeadersForCallResource(headers map[string][]string) map[string]string {
data := make(map[string]string)
for k, values := range headers {
k = textproto.CanonicalMIMEHeaderKey(k)
firstValue := arrayHeaderFirstValue(values)
if firstValue == "" {
continue
}
switch k {
case "Authorization":
data["Authorization"] = firstValue
case "X-Id-Token":
data["X-ID-Token"] = firstValue
case "Cookie":
data["Cookie"] = firstValue
case "Accept-Encoding":
data["Accept-Encoding"] = firstValue
}
}
return data
}
func callResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, dsInfo *datasourceInfo, plog log.Logger) error {
url := req.URL
@ -158,7 +117,7 @@ func callResource(ctx context.Context, req *backend.CallResourceRequest, sender
}
lokiURL := fmt.Sprintf("/loki/api/v1/%s", url)
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getHeadersForCallResource(req.Headers))
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog)
encodedBytes, err := api.RawQuery(ctx, lokiURL)
if err != nil {
@ -191,7 +150,7 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
func queryData(ctx context.Context, req *backend.QueryDataRequest, dsInfo *datasourceInfo, tracer tracing.Tracer) (*backend.QueryDataResponse, error) {
result := backend.NewQueryDataResponse()
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx), req.Headers)
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, logger.FromContext(ctx))
queries, err := parseQuery(req)
if err != nil {

View File

@ -1,74 +0,0 @@
package loki
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetHeadersForCallResource(t *testing.T) {
const idTokn1 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
const idTokn2 = "eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU"
const authTokn1 = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
const authTokn2 = "Bearer eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE2Njg2MjExODQsImlhdCI6MTY2ODYyMTE4NH0.bg0Y0S245DeANhNnnLBCfGYBseTld29O0xynhQwZZlU"
testCases := map[string]struct {
headers map[string][]string
expectedHeaders map[string]string
}{
"Headers with empty value": {
headers: map[string][]string{
"X-Grafana-Org-Id": {"1"},
"Cookie": {""},
"X-Id-Token": {""},
"Accept-Encoding": {""},
"Authorization": {""},
},
expectedHeaders: map[string]string{},
},
"Headers with multiple values": {
headers: map[string][]string{
"Authorization": {authTokn1, authTokn2},
"Cookie": {"a=1"},
"X-Grafana-Org-Id": {"1"},
"Accept-Encoding": {"gzip", "compress"},
"X-Id-Token": {idTokn1, idTokn2},
},
expectedHeaders: map[string]string{
"Authorization": authTokn1,
"Cookie": "a=1",
"Accept-Encoding": "gzip",
"X-ID-Token": idTokn1,
},
},
"Headers with single value": {
headers: map[string][]string{
"Authorization": {authTokn1},
"X-Grafana-Org-Id": {"1"},
"Cookie": {"a=1"},
"Accept-Encoding": {"gzip"},
"X-Id-Token": {idTokn1},
},
expectedHeaders: map[string]string{
"Authorization": authTokn1,
"Cookie": "a=1",
"Accept-Encoding": "gzip",
"X-ID-Token": idTokn1,
},
},
"Non Canonical 'X-Id-Token' header key": {
headers: map[string][]string{
"X-ID-TOKEN": {idTokn1},
},
expectedHeaders: map[string]string{
"X-ID-Token": idTokn1,
},
},
}
for name, test := range testCases {
t.Run(name, func(t *testing.T) {
headers := getHeadersForCallResource(test.headers)
assert.Equal(t, test.expectedHeaders, headers)
})
}
}

View File

@ -32,7 +32,7 @@ func NewClient(d doer, method, baseUrl string) *Client {
return &Client{doer: d, method: method, baseUrl: baseUrl}
}
func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
func (c *Client) QueryRange(ctx context.Context, q *models.Query) (*http.Response, error) {
tr := q.TimeRange()
qv := map[string]string{
"query": q.Expr,
@ -41,7 +41,7 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H
"step": strconv.FormatFloat(tr.Step.Seconds(), 'f', -1, 64),
}
req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv, headers)
req, err := c.createQueryRequest(ctx, "api/v1/query_range", qv)
if err != nil {
return nil, err
}
@ -49,14 +49,14 @@ func (c *Client) QueryRange(ctx context.Context, q *models.Query, headers http.H
return c.doer.Do(req)
}
func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
func (c *Client) QueryInstant(ctx context.Context, q *models.Query) (*http.Response, error) {
qv := map[string]string{"query": q.Expr}
tr := q.TimeRange()
if !tr.End.IsZero() {
qv["time"] = formatTime(tr.End)
}
req, err := c.createQueryRequest(ctx, "api/v1/query", qv, headers)
req, err := c.createQueryRequest(ctx, "api/v1/query", qv)
if err != nil {
return nil, err
}
@ -64,7 +64,7 @@ func (c *Client) QueryInstant(ctx context.Context, q *models.Query, headers http
return c.doer.Do(req)
}
func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers http.Header) (*http.Response, error) {
func (c *Client) QueryExemplars(ctx context.Context, q *models.Query) (*http.Response, error) {
tr := q.TimeRange()
qv := map[string]string{
"query": q.Expr,
@ -72,7 +72,7 @@ func (c *Client) QueryExemplars(ctx context.Context, q *models.Query, headers ht
"end": formatTime(tr.End),
}
req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv, headers)
req, err := c.createQueryRequest(ctx, "api/v1/query_exemplars", qv)
if err != nil {
return nil, err
}
@ -95,7 +95,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq
// We use method from the request, as for resources front end may do a fallback to GET if POST does not work
// nad we want to respect that.
httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body), req.Headers)
httpRequest, err := createRequest(ctx, req.Method, u, bytes.NewReader(req.Body))
if err != nil {
return nil, err
}
@ -103,7 +103,7 @@ func (c *Client) QueryResource(ctx context.Context, req *backend.CallResourceReq
return c.doer.Do(httpRequest)
}
func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string, headers http.Header) (*http.Request, error) {
func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map[string]string) (*http.Request, error) {
if strings.ToUpper(c.method) == http.MethodPost {
u, err := c.createUrl(endpoint, nil)
if err != nil {
@ -115,7 +115,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map
v.Set(key, val)
}
return createRequest(ctx, c.method, u, strings.NewReader(v.Encode()), headers)
return createRequest(ctx, c.method, u, strings.NewReader(v.Encode()))
}
u, err := c.createUrl(endpoint, qv)
@ -123,7 +123,7 @@ func (c *Client) createQueryRequest(ctx context.Context, endpoint string, qv map
return nil, err
}
return createRequest(ctx, c.method, u, http.NoBody, headers)
return createRequest(ctx, c.method, u, http.NoBody)
}
func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, error) {
@ -148,15 +148,12 @@ func (c *Client) createUrl(endpoint string, qs map[string]string) (*url.URL, err
return finalUrl, nil
}
func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader, header http.Header) (*http.Request, error) {
func createRequest(ctx context.Context, method string, u *url.URL, bodyReader io.Reader) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, method, u.String(), bodyReader)
if err != nil {
return nil, err
}
// request.Header is created empty from NewRequestWithContext so we can just replace it
if header != nil {
request.Header = header
}
if strings.ToUpper(method) == http.MethodPost {
// This may not be true but right now we don't have more information here and seems like we send just this type
// of encoding right now if it is a POST

View File

@ -37,7 +37,6 @@ func TestClient(t *testing.T) {
Path: "/api/v1/series",
Method: http.MethodPost,
URL: "/api/v1/series",
Headers: nil,
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
}
res, err := client.QueryResource(context.Background(), req)
@ -63,7 +62,6 @@ func TestClient(t *testing.T) {
Path: "/api/v1/series",
Method: http.MethodGet,
URL: "api/v1/series?match%5B%5D=ALERTS&start=1655272558&end=1655294158",
Headers: nil,
}
res, err := client.QueryResource(context.Background(), req)
defer func() {
@ -95,7 +93,7 @@ func TestClient(t *testing.T) {
RangeQuery: true,
Step: 1 * time.Second,
}
res, err := client.QueryRange(context.Background(), req, nil)
res, err := client.QueryRange(context.Background(), req)
defer func() {
if res != nil && res.Body != nil {
if err := res.Body.Close(); err != nil {
@ -122,7 +120,7 @@ func TestClient(t *testing.T) {
RangeQuery: true,
Step: 1 * time.Second,
}
res, err := client.QueryRange(context.Background(), req, nil)
res, err := client.QueryRange(context.Background(), req)
defer func() {
if res != nil && res.Body != nil {
if err := res.Body.Close(); err != nil {

View File

@ -85,11 +85,7 @@ func TestService(t *testing.T) {
Path: "/api/v1/series",
Method: http.MethodPost,
URL: "/api/v1/series",
// This header should be passed on to the resource request
Headers: map[string][]string{
"foo": {"bar"},
},
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
Body: []byte("match%5B%5D: ALERTS\nstart: 1655271408\nend: 1655293008"),
}
sender := &fakeSender{}
@ -100,7 +96,6 @@ func TestService(t *testing.T) {
http.Header{
"Content-Type": {"application/x-www-form-urlencoded"},
"Idempotency-Key": []string(nil),
"foo": {"bar"},
},
httpProvider.Roundtripper.Req.Header)
require.Equal(t, http.MethodPost, httpProvider.Roundtripper.Req.Method)

View File

@ -14,6 +14,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/featuremgmt"
ngalertmodels "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/tsdb/intervalv2"
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
"github.com/grafana/grafana/pkg/tsdb/prometheus/models"
@ -79,7 +80,7 @@ func New(
}
func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
fromAlert := req.Headers["FromAlert"] == "true"
fromAlert := req.Headers[ngalertmodels.FromAlertHeaderName] == "true"
result := backend.QueryDataResponse{
Responses: backend.Responses{},
}
@ -89,7 +90,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest)
if err != nil {
return &result, err
}
r, err := s.fetch(ctx, s.client, query, req.Headers)
r, err := s.fetch(ctx, s.client, query)
if err != nil {
return &result, err
}
@ -103,7 +104,7 @@ func (s *QueryData) Execute(ctx context.Context, req *backend.QueryDataRequest)
return &result, nil
}
func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.Query) (*backend.DataResponse, error) {
traceCtx, end := s.trace(ctx, q)
defer end()
@ -116,7 +117,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
}
if q.InstantQuery {
res, err := s.instantQuery(traceCtx, client, q, headers)
res, err := s.instantQuery(traceCtx, client, q)
if err != nil {
return nil, err
}
@ -125,7 +126,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
}
if q.RangeQuery {
res, err := s.rangeQuery(traceCtx, client, q, headers)
res, err := s.rangeQuery(traceCtx, client, q)
if err != nil {
return nil, err
}
@ -140,7 +141,7 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
}
if q.ExemplarQuery {
res, err := s.exemplarQuery(traceCtx, client, q, headers)
res, err := s.exemplarQuery(traceCtx, client, q)
if err != nil {
// If exemplar query returns error, we want to only log it and
// continue with other results processing
@ -154,24 +155,24 @@ func (s *QueryData) fetch(ctx context.Context, client *client.Client, q *models.
return response, nil
}
func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
res, err := c.QueryRange(ctx, q, sdkHeaderToHttpHeader(headers))
func (s *QueryData) rangeQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
res, err := c.QueryRange(ctx, q)
if err != nil {
return nil, err
}
return s.parseResponse(ctx, q, res)
}
func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
res, err := c.QueryInstant(ctx, q, sdkHeaderToHttpHeader(headers))
func (s *QueryData) instantQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
res, err := c.QueryInstant(ctx, q)
if err != nil {
return nil, err
}
return s.parseResponse(ctx, q, res)
}
func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query, headers map[string]string) (*backend.DataResponse, error) {
res, err := c.QueryExemplars(ctx, q, sdkHeaderToHttpHeader(headers))
func (s *QueryData) exemplarQuery(ctx context.Context, c *client.Client, q *models.Query) (*backend.DataResponse, error) {
res, err := c.QueryExemplars(ctx, q)
if err != nil {
return nil, err
}
@ -185,11 +186,3 @@ func (s *QueryData) trace(ctx context.Context, q *models.Query) (context.Context
{Key: "stop_unixnano", Value: q.End, Kv: attribute.Key("stop_unixnano").Int64(q.End.UnixNano())},
})
}
func sdkHeaderToHttpHeader(headers map[string]string) http.Header {
httpHeader := make(http.Header, len(headers))
for key, val := range headers {
httpHeader.Set(key, val)
}
return httpHeader
}

View File

@ -16,7 +16,6 @@ import (
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
p "github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/httpclient"
@ -343,27 +342,6 @@ func TestPrometheus_parseTimeSeriesResponse(t *testing.T) {
})
}
func TestPrometheusCanonicalHeaders(t *testing.T) {
// Ensure headers are always canonicalized for all outgoing requests
b, err := json.Marshal(models.QueryModel{})
require.NoError(t, err)
query := backend.DataQuery{JSON: b}
tctx, err := setup(true)
require.NoError(t, err)
const idToken = "abc"
_, err = executeWithHeaders(tctx, query, queryResult{}, map[string]string{
"X-Id-Token": idToken,
"X-ID-Token": idToken,
"X-Other": "thing",
})
require.NoError(t, err)
assert.NotEmpty(t, tctx.httpProvider.req.Header)
// Check the request that hit the fake prometheus server to ensure headers are valid
assert.Equal(t, []string{idToken}, tctx.httpProvider.req.Header["X-Id-Token"])
assert.Empty(t, tctx.httpProvider.req.Header["X-ID-Token"]) //nolint:staticcheck
assert.Equal(t, []string{"thing"}, tctx.httpProvider.req.Header["X-Other"])
}
type queryResult struct {
Type p.ValueType `json:"resultType"`
Result interface{} `json:"result"`