diff --git a/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware.go b/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware.go new file mode 100644 index 00000000000..0075578ae8a --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware.go @@ -0,0 +1,34 @@ +package httpclientprovider + +import ( + "net/http" + + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/pluginsintegration/clientmiddleware" + "github.com/grafana/grafana/pkg/setting" +) + +const GrafanaRequestIDHeaderMiddlewareName = "grafana-request-id-header-middleware" + +func GrafanaRequestIDHeaderMiddleware(cfg *setting.Cfg, logger log.Logger) sdkhttpclient.Middleware { + return sdkhttpclient.NamedMiddlewareFunc(GrafanaRequestIDHeaderMiddlewareName, func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { + return sdkhttpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header.Get(clientmiddleware.GrafanaRequestID) != "" { + logger.Debug("Request already has a Grafana request ID header", "request_id", req.Header.Get(clientmiddleware.GrafanaRequestID)) + return next.RoundTrip(req) + } + + if !clientmiddleware.IsRequestURLInAllowList(req.URL, cfg) { + logger.Debug("Data source URL not among the allow-listed URLs", "url", req.URL.String()) + return next.RoundTrip(req) + } + + for k, v := range clientmiddleware.GetGrafanaRequestIDHeaders(req, cfg, logger) { + req.Header.Set(k, v) + } + + return next.RoundTrip(req) + }) + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware_test.go new file mode 100644 index 00000000000..92c995915b4 --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/grafana_request_id_header_middleware_test.go @@ -0,0 +1,113 @@ +package httpclientprovider + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "net/http" + "net/url" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/pluginsintegration/clientmiddleware" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/require" +) + +func TestGrafanaRequestIDHeaderMiddleware(t *testing.T) { + testCases := []struct { + description string + allowedURLs []*url.URL + requestURL string + remoteAddress string + expectGrafanaRequestIDHeaders bool + expectPrivateRequestHeader bool + }{ + { + description: "With target URL in the allowed URL list and remote address specified, should add headers to the request but the request should not be marked as private", + allowedURLs: []*url.URL{{ + Scheme: "https", + Host: "grafana.com", + }}, + requestURL: "https://grafana.com/api/some/path", + remoteAddress: "1.2.3.4", + expectGrafanaRequestIDHeaders: true, + expectPrivateRequestHeader: false, + }, + { + description: "With target URL in the allowed URL list and remote address not specified, should add headers to the request and the request should be marked as private", + allowedURLs: []*url.URL{{ + Scheme: "https", + Host: "grafana.com", + }}, + requestURL: "https://grafana.com/api/some/path", + expectGrafanaRequestIDHeaders: true, + expectPrivateRequestHeader: true, + }, + { + description: "With target URL not in the allowed URL list, should not add headers to the request", + allowedURLs: []*url.URL{{ + Scheme: "https", + Host: "grafana.com", + }}, + requestURL: "https://fake-grafana.com/api/some/path", + expectGrafanaRequestIDHeaders: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + ctx := &testContext{} + finalRoundTripper := ctx.createRoundTripper("final") + cfg := setting.NewCfg() + cfg.IPRangeACEnabled = false + cfg.IPRangeACAllowedURLs = tc.allowedURLs + cfg.IPRangeACSecretKey = "secret" + mw := GrafanaRequestIDHeaderMiddleware(cfg, log.New("test")) + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, GrafanaRequestIDHeaderMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, tc.requestURL, nil) + require.NoError(t, err) + req.RemoteAddr = tc.remoteAddress + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + require.Len(t, ctx.callChain, 1) + require.ElementsMatch(t, []string{"final"}, ctx.callChain) + + if !tc.expectGrafanaRequestIDHeaders { + require.Len(t, req.Header.Values(clientmiddleware.GrafanaRequestID), 0) + require.Len(t, req.Header.Values(clientmiddleware.GrafanaSignedRequestID), 0) + } else { + require.Len(t, req.Header.Values(clientmiddleware.GrafanaRequestID), 1) + require.Len(t, req.Header.Values(clientmiddleware.GrafanaSignedRequestID), 1) + requestID := req.Header.Get(clientmiddleware.GrafanaRequestID) + + instance := hmac.New(sha256.New, []byte(cfg.IPRangeACSecretKey)) + _, err = instance.Write([]byte(requestID)) + require.NoError(t, err) + computed := hex.EncodeToString(instance.Sum(nil)) + + require.Equal(t, req.Header.Get(clientmiddleware.GrafanaSignedRequestID), computed) + + if tc.remoteAddress == "" { + require.Equal(t, req.Header.Get(clientmiddleware.GrafanaInternalRequest), "true") + } else { + require.Len(t, req.Header.Values(clientmiddleware.XRealIPHeader), 1) + require.Equal(t, req.Header.Get(clientmiddleware.XRealIPHeader), tc.remoteAddress) + + // Internal header should not be set + require.Len(t, req.Header.Values(clientmiddleware.GrafanaInternalRequest), 0) + } + } + }) + } +} diff --git a/pkg/infra/httpclient/httpclientprovider/http_client_provider.go b/pkg/infra/httpclient/httpclientprovider/http_client_provider.go index 16c7b197049..69a9f2a1c73 100644 --- a/pkg/infra/httpclient/httpclientprovider/http_client_provider.go +++ b/pkg/infra/httpclient/httpclientprovider/http_client_provider.go @@ -39,6 +39,10 @@ func New(cfg *setting.Cfg, validator validations.PluginRequestValidator, tracer middlewares = append(middlewares, HTTPLoggerMiddleware(cfg.PluginSettings)) } + if cfg.IPRangeACEnabled { + middlewares = append(middlewares, GrafanaRequestIDHeaderMiddleware(cfg, logger)) + } + setDefaultTimeoutOptions(cfg) return newProviderFunc(sdkhttpclient.ProviderOptions{ diff --git a/pkg/services/pluginsintegration/clientmiddleware/grafana_request_id_header_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/grafana_request_id_header_middleware.go index 7637c8b0513..15c2a1c1ccc 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/grafana_request_id_header_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/grafana_request_id_header_middleware.go @@ -5,6 +5,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "net/http" "net/url" "github.com/google/uuid" @@ -55,45 +56,63 @@ func (m *HostedGrafanaACHeaderMiddleware) applyGrafanaRequestIDHeader(ctx contex m.log.Debug("Failed to parse data source URL", "error", err) return } - foundMatch := false - for _, allowedURL := range m.cfg.IPRangeACAllowedURLs { - // Only look at the scheme and host, ignore the path - if allowedURL.Host == dsBaseURL.Host && allowedURL.Scheme == dsBaseURL.Scheme { - foundMatch = true - break - } - } - if !foundMatch { + if !IsRequestURLInAllowList(dsBaseURL, m.cfg) { m.log.Debug("Data source URL not among the allow-listed URLs", "url", dsBaseURL.String()) return } + var req *http.Request + reqCtx := contexthandler.FromContext(ctx) + if reqCtx != nil { + req = reqCtx.Req + } + for k, v := range GetGrafanaRequestIDHeaders(req, m.cfg, m.log) { + h.SetHTTPHeader(k, v) + } +} + +func IsRequestURLInAllowList(url *url.URL, cfg *setting.Cfg) bool { + for _, allowedURL := range cfg.IPRangeACAllowedURLs { + // Only look at the scheme and host, ignore the path + if allowedURL.Host == url.Host && allowedURL.Scheme == url.Scheme { + return true + } + } + return false +} + +func GetGrafanaRequestIDHeaders(req *http.Request, cfg *setting.Cfg, logger log.Logger) map[string]string { // Generate a new Grafana request ID and sign it with the secret key uid, err := uuid.NewRandom() if err != nil { - m.log.Debug("Failed to generate Grafana request ID", "error", err) - return + logger.Debug("Failed to generate Grafana request ID", "error", err) + return nil } grafanaRequestID := uid.String() - hmac := hmac.New(sha256.New, []byte(m.cfg.IPRangeACSecretKey)) + hmac := hmac.New(sha256.New, []byte(cfg.IPRangeACSecretKey)) if _, err := hmac.Write([]byte(grafanaRequestID)); err != nil { - m.log.Debug("Failed to sign IP range access control header", "error", err) - return + logger.Debug("Failed to sign IP range access control header", "error", err) + return nil } signedGrafanaRequestID := hex.EncodeToString(hmac.Sum(nil)) - h.SetHTTPHeader(GrafanaSignedRequestID, signedGrafanaRequestID) - h.SetHTTPHeader(GrafanaRequestID, grafanaRequestID) - reqCtx := contexthandler.FromContext(ctx) - if reqCtx != nil && reqCtx.Req != nil { - remoteAddress := web.RemoteAddr(reqCtx.Req) - if remoteAddress != "" { - h.SetHTTPHeader(XRealIPHeader, remoteAddress) - return - } + headers := make(map[string]string) + headers[GrafanaRequestID] = grafanaRequestID + headers[GrafanaSignedRequestID] = signedGrafanaRequestID + + // If the remote address is not specified, treat the request as internal + remoteAddress := "" + if req != nil { + remoteAddress = web.RemoteAddr(req) } - h.SetHTTPHeader(GrafanaInternalRequest, "true") + if remoteAddress != "" { + headers[XRealIPHeader] = remoteAddress + } else { + headers[GrafanaInternalRequest] = "true" + } + + return headers } func (m *HostedGrafanaACHeaderMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index d6a3b52bec1..916b296540b 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -1944,6 +1944,9 @@ func (cfg *Cfg) readDataSourceSecuritySettings() { datasources := cfg.Raw.Section("datasources.ip_range_security") cfg.IPRangeACEnabled = datasources.Key("enabled").MustBool(false) cfg.IPRangeACSecretKey = datasources.Key("secret_key").MustString("") + if cfg.IPRangeACEnabled && cfg.IPRangeACSecretKey == "" { + cfg.Logger.Error("IP range access control is enabled but no secret key is set") + } allowedURLString := datasources.Key("allow_list").MustString("") for _, urlString := range util.SplitString(allowedURLString) { allowedURL, err := url.Parse(urlString)