diff --git a/pkg/tsdb/prometheus/custom_query_params_middleware.go b/pkg/tsdb/prometheus/custom_query_params_middleware.go index 88f23d10c9c..b5c75c1cb83 100644 --- a/pkg/tsdb/prometheus/custom_query_params_middleware.go +++ b/pkg/tsdb/prometheus/custom_query_params_middleware.go @@ -1,12 +1,11 @@ package prometheus import ( - "fmt" "net/http" "net/url" - "strings" sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/log" ) const ( @@ -14,7 +13,7 @@ const ( customQueryParametersKey = "customQueryParameters" ) -func customQueryParametersMiddleware() sdkhttpclient.Middleware { +func customQueryParametersMiddleware(logger log.Logger) sdkhttpclient.Middleware { return sdkhttpclient.NamedMiddlewareFunc(customQueryParametersMiddlewareName, func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { customQueryParamsVal, exists := opts.CustomOptions[customQueryParametersKey] if !exists { @@ -25,22 +24,20 @@ func customQueryParametersMiddleware() sdkhttpclient.Middleware { return next } + values, err := url.ParseQuery(customQueryParams) + if err != nil { + logger.Error("Failed to parse custom query parameters, skipping middleware", "error", err) + return next + } + return sdkhttpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - params := url.Values{} - for _, param := range strings.Split(customQueryParams, "&") { - parts := strings.Split(param, "=") - if len(parts) == 1 { - // This is probably a mistake on the users part in defining the params but we don't want to crash. - params.Add(parts[0], "") - } else { - params.Add(parts[0], parts[1]) + q := req.URL.Query() + for k, keyValues := range values { + for _, value := range keyValues { + q.Add(k, value) } } - if req.URL.RawQuery != "" { - req.URL.RawQuery = fmt.Sprintf("%s&%s", req.URL.RawQuery, params.Encode()) - } else { - req.URL.RawQuery = params.Encode() - } + req.URL.RawQuery = q.Encode() return next.RoundTrip(req) }) diff --git a/pkg/tsdb/prometheus/custom_query_params_middleware_test.go b/pkg/tsdb/prometheus/custom_query_params_middleware_test.go index 444ffcf3d5f..643b8c3753c 100644 --- a/pkg/tsdb/prometheus/custom_query_params_middleware_test.go +++ b/pkg/tsdb/prometheus/custom_query_params_middleware_test.go @@ -2,9 +2,12 @@ package prometheus import ( "net/http" + "net/url" + "strings" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/log" "github.com/stretchr/testify/require" ) @@ -16,7 +19,7 @@ func TestCustomQueryParametersMiddleware(t *testing.T) { }) t.Run("Without custom query parameters set should not apply middleware", func(t *testing.T) { - mw := customQueryParametersMiddleware() + mw := customQueryParametersMiddleware(log.New("test")) rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) require.NotNil(t, rt) middlewareName, ok := mw.(httpclient.MiddlewareName) @@ -36,7 +39,7 @@ func TestCustomQueryParametersMiddleware(t *testing.T) { }) t.Run("Without custom query parameters set as string should not apply middleware", func(t *testing.T) { - mw := customQueryParametersMiddleware() + mw := customQueryParametersMiddleware(log.New("test")) rt := mw.CreateMiddleware(httpclient.Options{ CustomOptions: map[string]interface{}{ customQueryParametersKey: 64, @@ -60,7 +63,7 @@ func TestCustomQueryParametersMiddleware(t *testing.T) { }) t.Run("With custom query parameters set as empty string should not apply middleware", func(t *testing.T) { - mw := customQueryParametersMiddleware() + mw := customQueryParametersMiddleware(log.New("test")) rt := mw.CreateMiddleware(httpclient.Options{ CustomOptions: map[string]interface{}{ customQueryParametersKey: "", @@ -83,8 +86,32 @@ func TestCustomQueryParametersMiddleware(t *testing.T) { require.Equal(t, "http://test.com/query?hello=name", req.URL.String()) }) - t.Run("With custom query parameters set as string should apply middleware", func(t *testing.T) { - mw := customQueryParametersMiddleware() + t.Run("With custom query parameters set as invalid query string should not apply middleware", func(t *testing.T) { + mw := customQueryParametersMiddleware(log.New("test")) + rt := mw.CreateMiddleware(httpclient.Options{ + CustomOptions: map[string]interface{}{ + customQueryParametersKey: "custom=%%abc&test=abc", + }, + }, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query?hello=name", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?hello=name", req.URL.String()) + }) + + t.Run("With custom query parameters set should apply middleware for request URL containing query parameters ", func(t *testing.T) { + mw := customQueryParametersMiddleware(log.New("test")) rt := mw.CreateMiddleware(httpclient.Options{ CustomOptions: map[string]interface{}{ customQueryParametersKey: "custom=par/am&second=f oo", @@ -104,6 +131,36 @@ func TestCustomQueryParametersMiddleware(t *testing.T) { require.NoError(t, res.Body.Close()) } - require.Equal(t, "http://test.com/query?hello=name&custom=par%2Fam&second=f+oo", req.URL.String()) + require.True(t, strings.HasPrefix(req.URL.String(), "http://test.com/query?")) + + q := req.URL.Query() + require.Len(t, q, 3) + require.Equal(t, "name", url.QueryEscape(q.Get("hello"))) + require.Equal(t, "par%2Fam", url.QueryEscape(q.Get("custom"))) + require.Equal(t, "f+oo", url.QueryEscape(q.Get("second"))) + }) + + t.Run("With custom query parameters set should apply middleware for request URL not containing query parameters", func(t *testing.T) { + mw := customQueryParametersMiddleware(log.New("test")) + rt := mw.CreateMiddleware(httpclient.Options{ + CustomOptions: map[string]interface{}{ + customQueryParametersKey: "custom=par/am&second=f oo", + }, + }, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?custom=par%2Fam&second=f+oo", req.URL.String()) }) } diff --git a/pkg/tsdb/prometheus/prometheus.go b/pkg/tsdb/prometheus/prometheus.go index edf99ddf6e3..0865e9fac05 100644 --- a/pkg/tsdb/prometheus/prometheus.go +++ b/pkg/tsdb/prometheus/prometheus.go @@ -37,7 +37,7 @@ type PrometheusExecutor struct { //nolint: staticcheck // plugins.DataPlugin deprecated func New(provider httpclient.Provider) func(*models.DataSource) (plugins.DataPlugin, error) { return func(dsInfo *models.DataSource) (plugins.DataPlugin, error) { - transport, err := dsInfo.GetHTTPTransport(provider, customQueryParametersMiddleware()) + transport, err := dsInfo.GetHTTPTransport(provider, customQueryParametersMiddleware(plog)) if err != nil { return nil, err }