mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Datasource: Shared HTTP client provider for core backend data sources and any data source using the data source proxy (#33439)
Uses new httpclient package from grafana-plugin-sdk-go introduced via grafana/grafana-plugin-sdk-go#328. Replaces the GetHTTPClient, GetTransport, GetTLSConfig methods defined on DataSource model. Longer-term the goal is to migrate core HTTP backend data sources to use the SDK contracts and using httpclient.Provider for creating HTTP clients and such. Co-authored-by: Arve Knudsen <arve.knudsen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7a83d1f9ff
commit
348e76fc8e
@@ -0,0 +1,105 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/metrics/metricutil"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
var datasourceRequestCounter = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: "grafana",
|
||||
Name: "datasource_request_total",
|
||||
Help: "A counter for outgoing requests for a datasource",
|
||||
},
|
||||
[]string{"datasource", "code", "method"},
|
||||
)
|
||||
|
||||
var datasourceRequestSummary = prometheus.NewSummaryVec(
|
||||
prometheus.SummaryOpts{
|
||||
Namespace: "grafana",
|
||||
Name: "datasource_request_duration_seconds",
|
||||
Help: "summary of outgoing datasource requests sent from Grafana",
|
||||
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
|
||||
}, []string{"datasource", "code", "method"},
|
||||
)
|
||||
|
||||
var datasourceResponseSummary = prometheus.NewSummaryVec(
|
||||
prometheus.SummaryOpts{
|
||||
Namespace: "grafana",
|
||||
Name: "datasource_response_size_bytes",
|
||||
Help: "summary of datasource response sizes returned to Grafana",
|
||||
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
|
||||
}, []string{"datasource"},
|
||||
)
|
||||
|
||||
var datasourceRequestsInFlight = prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "grafana",
|
||||
Name: "datasource_request_in_flight",
|
||||
Help: "A gauge of outgoing datasource requests currently being sent by Grafana",
|
||||
},
|
||||
[]string{"datasource"},
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(datasourceRequestSummary,
|
||||
datasourceRequestCounter,
|
||||
datasourceRequestsInFlight,
|
||||
datasourceResponseSummary)
|
||||
}
|
||||
|
||||
const DataSourceMetricsMiddlewareName = "metrics"
|
||||
|
||||
var executeMiddlewareFunc = executeMiddleware
|
||||
|
||||
func DataSourceMetricsMiddleware() httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(DataSourceMetricsMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if opts.Labels == nil {
|
||||
return next
|
||||
}
|
||||
|
||||
datasourceName, exists := opts.Labels["datasource_name"]
|
||||
if !exists {
|
||||
return next
|
||||
}
|
||||
|
||||
datasourceLabelName, err := metricutil.SanitizeLabelName(datasourceName)
|
||||
// if the datasource named cannot be turned into a prometheus
|
||||
// label we will skip instrumenting these metrics.
|
||||
if err != nil {
|
||||
return next
|
||||
}
|
||||
|
||||
datasourceLabel := prometheus.Labels{"datasource": datasourceLabelName}
|
||||
|
||||
return executeMiddlewareFunc(next, datasourceLabel)
|
||||
})
|
||||
}
|
||||
|
||||
func executeMiddleware(next http.RoundTripper, datasourceLabel prometheus.Labels) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
requestCounter := datasourceRequestCounter.MustCurryWith(datasourceLabel)
|
||||
requestSummary := datasourceRequestSummary.MustCurryWith(datasourceLabel)
|
||||
requestInFlight := datasourceRequestsInFlight.With(datasourceLabel)
|
||||
responseSizeSummary := datasourceResponseSummary.With(datasourceLabel)
|
||||
|
||||
res, err := promhttp.InstrumentRoundTripperDuration(requestSummary,
|
||||
promhttp.InstrumentRoundTripperCounter(requestCounter,
|
||||
promhttp.InstrumentRoundTripperInFlight(requestInFlight, next))).
|
||||
RoundTrip(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// we avoid measuring contentlength less than zero because it indicates
|
||||
// that the content size is unknown. https://godoc.org/github.com/badu/http#Response
|
||||
if res != nil && res.ContentLength > 0 {
|
||||
responseSizeSummary.Observe(float64(res.ContentLength))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
})
|
||||
}
|
@@ -0,0 +1,130 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDataSourceMetricsMiddleware(t *testing.T) {
|
||||
t.Run("Without label options set should return next http.RoundTripper", func(t *testing.T) {
|
||||
origExecuteMiddlewareFunc := executeMiddlewareFunc
|
||||
executeMiddlewareCalled := false
|
||||
middlewareCalled := false
|
||||
executeMiddlewareFunc = func(next http.RoundTripper, datasourceLabel prometheus.Labels) http.RoundTripper {
|
||||
executeMiddlewareCalled = true
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
middlewareCalled = true
|
||||
return next.RoundTrip(r)
|
||||
})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
executeMiddlewareFunc = origExecuteMiddlewareFunc
|
||||
})
|
||||
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
mw := DataSourceMetricsMiddleware()
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, DataSourceMetricsMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
require.False(t, executeMiddlewareCalled)
|
||||
require.False(t, middlewareCalled)
|
||||
})
|
||||
|
||||
t.Run("Without data source name label options set should return next http.RoundTripper", func(t *testing.T) {
|
||||
origExecuteMiddlewareFunc := executeMiddlewareFunc
|
||||
executeMiddlewareCalled := false
|
||||
middlewareCalled := false
|
||||
executeMiddlewareFunc = func(next http.RoundTripper, datasourceLabel prometheus.Labels) http.RoundTripper {
|
||||
executeMiddlewareCalled = true
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
middlewareCalled = true
|
||||
return next.RoundTrip(r)
|
||||
})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
executeMiddlewareFunc = origExecuteMiddlewareFunc
|
||||
})
|
||||
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
mw := DataSourceMetricsMiddleware()
|
||||
rt := mw.CreateMiddleware(httpclient.Options{Labels: map[string]string{"test": "test"}}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, DataSourceMetricsMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
require.False(t, executeMiddlewareCalled)
|
||||
require.False(t, middlewareCalled)
|
||||
})
|
||||
|
||||
t.Run("With datasource name label options set should execute middleware", func(t *testing.T) {
|
||||
origExecuteMiddlewareFunc := executeMiddlewareFunc
|
||||
executeMiddlewareCalled := false
|
||||
datasourceLabels := prometheus.Labels{}
|
||||
middlewareCalled := false
|
||||
executeMiddlewareFunc = func(next http.RoundTripper, datasourceLabel prometheus.Labels) http.RoundTripper {
|
||||
executeMiddlewareCalled = true
|
||||
datasourceLabels = datasourceLabel
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
middlewareCalled = true
|
||||
return next.RoundTrip(r)
|
||||
})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
executeMiddlewareFunc = origExecuteMiddlewareFunc
|
||||
})
|
||||
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
mw := DataSourceMetricsMiddleware()
|
||||
rt := mw.CreateMiddleware(httpclient.Options{Labels: map[string]string{"datasource_name": "My Data Source 123"}}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, DataSourceMetricsMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
require.True(t, executeMiddlewareCalled)
|
||||
require.Len(t, datasourceLabels, 1)
|
||||
require.Equal(t, "My_Data_Source_123", datasourceLabels["datasource"])
|
||||
require.True(t, middlewareCalled)
|
||||
})
|
||||
}
|
@@ -0,0 +1,30 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var newProviderFunc = sdkhttpclient.NewProvider
|
||||
|
||||
// New creates a new HTTP client provider with pre-configured middlewares.
|
||||
func New(cfg *setting.Cfg) httpclient.Provider {
|
||||
userAgent := fmt.Sprintf("Grafana/%s", cfg.BuildVersion)
|
||||
middlewares := []sdkhttpclient.Middleware{
|
||||
DataSourceMetricsMiddleware(),
|
||||
SetUserAgentMiddleware(userAgent),
|
||||
sdkhttpclient.BasicAuthenticationMiddleware(),
|
||||
sdkhttpclient.CustomHeadersMiddleware(),
|
||||
}
|
||||
|
||||
if cfg.SigV4AuthEnabled {
|
||||
middlewares = append(middlewares, SigV4Middleware())
|
||||
}
|
||||
|
||||
return newProviderFunc(sdkhttpclient.ProviderOptions{
|
||||
Middlewares: middlewares,
|
||||
})
|
||||
}
|
@@ -0,0 +1,52 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPClientProvider(t *testing.T) {
|
||||
t.Run("When creating new provider and SigV4 is disabled should apply expected middleware", func(t *testing.T) {
|
||||
origNewProviderFunc := newProviderFunc
|
||||
providerOpts := []sdkhttpclient.ProviderOptions{}
|
||||
newProviderFunc = func(opts ...sdkhttpclient.ProviderOptions) *sdkhttpclient.Provider {
|
||||
providerOpts = opts
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
newProviderFunc = origNewProviderFunc
|
||||
})
|
||||
_ = New(&setting.Cfg{SigV4AuthEnabled: false})
|
||||
require.Len(t, providerOpts, 1)
|
||||
o := providerOpts[0]
|
||||
require.Len(t, o.Middlewares, 4)
|
||||
require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
|
||||
t.Run("When creating new provider and SigV4 is enabled should apply expected middleware", func(t *testing.T) {
|
||||
origNewProviderFunc := newProviderFunc
|
||||
providerOpts := []sdkhttpclient.ProviderOptions{}
|
||||
newProviderFunc = func(opts ...sdkhttpclient.ProviderOptions) *sdkhttpclient.Provider {
|
||||
providerOpts = opts
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
newProviderFunc = origNewProviderFunc
|
||||
})
|
||||
_ = New(&setting.Cfg{SigV4AuthEnabled: true})
|
||||
require.Len(t, providerOpts, 1)
|
||||
o := providerOpts[0]
|
||||
require.Len(t, o.Middlewares, 5)
|
||||
require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
require.Equal(t, SigV4MiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
|
||||
})
|
||||
}
|
36
pkg/infra/httpclient/httpclientprovider/sigv4_middleware.go
Normal file
36
pkg/infra/httpclient/httpclientprovider/sigv4_middleware.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-aws-sdk/pkg/sigv4"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
// SigV4MiddlewareName the middleware name used by SigV4Middleware.
|
||||
const SigV4MiddlewareName = "sigv4"
|
||||
|
||||
var newSigV4Func = sigv4.New
|
||||
|
||||
// SigV4Middleware applies AWS Signature Version 4 request signing for the outgoing request.
|
||||
func SigV4Middleware() httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(SigV4MiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if opts.SigV4 == nil {
|
||||
return next
|
||||
}
|
||||
|
||||
return newSigV4Func(
|
||||
&sigv4.Config{
|
||||
Service: opts.SigV4.Service,
|
||||
AccessKey: opts.SigV4.AccessKey,
|
||||
SecretKey: opts.SigV4.SecretKey,
|
||||
Region: opts.SigV4.Region,
|
||||
AssumeRoleARN: opts.SigV4.AssumeRoleARN,
|
||||
AuthType: opts.SigV4.AuthType,
|
||||
ExternalID: opts.SigV4.ExternalID,
|
||||
Profile: opts.SigV4.Profile,
|
||||
},
|
||||
next,
|
||||
)
|
||||
})
|
||||
}
|
@@ -0,0 +1,89 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-aws-sdk/pkg/sigv4"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSigV4Middleware(t *testing.T) {
|
||||
t.Run("Without sigv4 options set should return next http.RoundTripper", func(t *testing.T) {
|
||||
origSigV4Func := newSigV4Func
|
||||
newSigV4Called := false
|
||||
middlewareCalled := false
|
||||
newSigV4Func = func(config *sigv4.Config, next http.RoundTripper) http.RoundTripper {
|
||||
newSigV4Called = true
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
middlewareCalled = true
|
||||
return next.RoundTrip(r)
|
||||
})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
newSigV4Func = origSigV4Func
|
||||
})
|
||||
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
mw := SigV4Middleware()
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SigV4MiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
require.False(t, newSigV4Called)
|
||||
require.False(t, middlewareCalled)
|
||||
})
|
||||
|
||||
t.Run("With sigv4 options set should call sigv4 http.RoundTripper", func(t *testing.T) {
|
||||
origSigV4Func := newSigV4Func
|
||||
newSigV4Called := false
|
||||
middlewareCalled := false
|
||||
newSigV4Func = func(config *sigv4.Config, next http.RoundTripper) http.RoundTripper {
|
||||
newSigV4Called = true
|
||||
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
middlewareCalled = true
|
||||
return next.RoundTrip(r)
|
||||
})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
newSigV4Func = origSigV4Func
|
||||
})
|
||||
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("final")
|
||||
mw := SigV4Middleware()
|
||||
rt := mw.CreateMiddleware(httpclient.Options{SigV4: &httpclient.SigV4Config{}}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SigV4MiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
|
||||
|
||||
require.True(t, newSigV4Called)
|
||||
require.True(t, middlewareCalled)
|
||||
})
|
||||
}
|
18
pkg/infra/httpclient/httpclientprovider/testing.go
Normal file
18
pkg/infra/httpclient/httpclientprovider/testing.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
type testContext struct {
|
||||
callChain []string
|
||||
}
|
||||
|
||||
func (c *testContext) createRoundTripper(name string) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
c.callChain = append(c.callChain, name)
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
})
|
||||
}
|
@@ -0,0 +1,27 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
// SetUserAgentMiddlewareName is the middleware name used by SetUserAgentMiddleware.
|
||||
const SetUserAgentMiddlewareName = "user-agent"
|
||||
|
||||
// SetUserAgentMiddleware is middleware that sets the HTTP header User-Agent on the outgoing request.
|
||||
// If User-Agent already set, it will not be overridden by this middleware.
|
||||
func SetUserAgentMiddleware(userAgent string) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(SetUserAgentMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if userAgent == "" {
|
||||
return next
|
||||
}
|
||||
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
})
|
||||
}
|
@@ -0,0 +1,82 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomHeadersMiddleware(t *testing.T) {
|
||||
t.Run("Without user agent set should return next http.RoundTripper", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("finalrt")
|
||||
mw := SetUserAgentMiddleware("")
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SetUserAgentMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
|
||||
})
|
||||
|
||||
t.Run("With user agent set should apply HTTP headers to the request", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("final")
|
||||
mw := SetUserAgentMiddleware("Grafana/8.0.0")
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SetUserAgentMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
|
||||
|
||||
require.Equal(t, "Grafana/8.0.0", req.Header.Get("User-Agent"))
|
||||
})
|
||||
|
||||
t.Run("With user agent set, but request already has User-Agent header set should not apply HTTP headers to the request", func(t *testing.T) {
|
||||
ctx := &testContext{}
|
||||
finalRoundTripper := ctx.createRoundTripper("final")
|
||||
mw := SetUserAgentMiddleware("Grafana/8.0.0")
|
||||
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
|
||||
require.NotNil(t, rt)
|
||||
middlewareName, ok := mw.(httpclient.MiddlewareName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, SetUserAgentMiddlewareName, middlewareName.MiddlewareName())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("User-Agent", "ua")
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
if res.Body != nil {
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
require.Len(t, ctx.callChain, 1)
|
||||
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
|
||||
|
||||
require.Equal(t, "ua", req.Header.Get("User-Agent"))
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user