From 4539c33fce5ef23badb08ebcbc09cb0cecb1f539 Mon Sep 17 00:00:00 2001 From: Will Browne Date: Mon, 3 Oct 2022 21:05:03 +0200 Subject: [PATCH] Plugin fixes (#562) * Plugins: Remove support for V1 manifests * Plugins: Make proxy endpoints not leak sensitive HTTP headers * Security: Fix do not forward login cookie in outgoing requests --- pkg/api/datasources.go | 2 +- pkg/api/metrics_test.go | 7 +-- pkg/api/plugin_resource.go | 11 +++- pkg/api/pluginproxy/ds_proxy.go | 2 +- pkg/api/plugins_test.go | 8 +++ .../forwarded_cookie_middleware_test.go | 9 +++- .../forwarded_cookies_middleware.go | 4 +- pkg/middleware/middleware_basic_auth_test.go | 6 +++ pkg/middleware/middleware_jwt_auth_test.go | 4 ++ pkg/middleware/middleware_test.go | 5 ++ pkg/plugins/manager/loader/loader_test.go | 8 +-- pkg/plugins/manager/signature/manifest.go | 30 ++++++----- pkg/services/contexthandler/auth_jwt.go | 3 ++ pkg/services/contexthandler/contexthandler.go | 54 ++++++++++++++++++- pkg/services/query/query.go | 4 +- pkg/services/query/query_test.go | 3 +- pkg/util/proxyutil/proxyutil.go | 23 ++++++-- pkg/util/proxyutil/proxyutil_test.go | 16 +++++- pkg/util/proxyutil/reverse_proxy.go | 8 +++ pkg/util/proxyutil/reverse_proxy_test.go | 7 +++ 20 files changed, 176 insertions(+), 38 deletions(-) diff --git a/pkg/api/datasources.go b/pkg/api/datasources.go index 014f2378e02..2ad595a1ba9 100644 --- a/pkg/api/datasources.go +++ b/pkg/api/datasources.go @@ -827,7 +827,7 @@ func (hs *HTTPServer) checkDatasourceHealth(c *models.ReqContext, ds *datasource } } - proxyutil.ClearCookieHeader(c.Req, ds.AllowedCookies()) + proxyutil.ClearCookieHeader(c.Req, ds.AllowedCookies(), []string{hs.Cfg.LoginCookieName}) if cookieStr := c.Req.Header.Get("Cookie"); cookieStr != "" { req.Headers["Cookie"] = cookieStr } diff --git a/pkg/api/metrics_test.go b/pkg/api/metrics_test.go index afb7fb303b6..1fd5986f6f4 100644 --- a/pkg/api/metrics_test.go +++ b/pkg/api/metrics_test.go @@ -26,6 +26,7 @@ import ( "github.com/grafana/grafana/pkg/services/query" "github.com/grafana/grafana/pkg/services/quota/quotatest" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util/errutil" "github.com/grafana/grafana/pkg/web/webtest" ) @@ -59,7 +60,7 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) // `/ds/query` endpoint test func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) { qds := query.ProvideService( - nil, + setting.NewCfg(), nil, nil, &fakePluginRequestValidator{}, @@ -108,7 +109,7 @@ func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) { func TestAPIEndpoint_Metrics_PluginDecryptionFailure(t *testing.T) { qds := query.ProvideService( - nil, + setting.NewCfg(), nil, nil, &fakePluginRequestValidator{}, @@ -271,7 +272,7 @@ func TestDataSourceQueryError(t *testing.T) { err := r.Add(context.Background(), p) require.NoError(t, err) hs.queryDataService = query.ProvideService( - nil, + setting.NewCfg(), &fakeDatasources.FakeCacheService{}, nil, &fakePluginRequestValidator{}, diff --git a/pkg/api/plugin_resource.go b/pkg/api/plugin_resource.go index 6dfaf912ab4..bccbb034ea8 100644 --- a/pkg/api/plugin_resource.go +++ b/pkg/api/plugin_resource.go @@ -14,6 +14,7 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins/backendplugin" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/util/proxyutil" "github.com/grafana/grafana/pkg/web" @@ -117,7 +118,15 @@ func (hs *HTTPServer) makePluginResourceRequest(w http.ResponseWriter, req *http hs.log.Warn("failed to unpack JSONData in datasource instance settings", "err", err) } } - proxyutil.ClearCookieHeader(req, keepCookieModel.KeepCookies) + + list := contexthandler.AuthHTTPHeaderListFromContext(req.Context()) + if list != nil { + for _, name := range list.Items { + req.Header.Del(name) + } + } + + proxyutil.ClearCookieHeader(req, keepCookieModel.KeepCookies, []string{hs.Cfg.LoginCookieName}) proxyutil.PrepareProxyRequest(req) body, err := io.ReadAll(req.Body) diff --git a/pkg/api/pluginproxy/ds_proxy.go b/pkg/api/pluginproxy/ds_proxy.go index 89ab1d0a972..c48f6a900ad 100644 --- a/pkg/api/pluginproxy/ds_proxy.go +++ b/pkg/api/pluginproxy/ds_proxy.go @@ -224,7 +224,7 @@ func (proxy *DataSourceProxy) director(req *http.Request) { applyUserHeader(proxy.cfg.SendUserHeader, req, proxy.ctx.SignedInUser) - proxyutil.ClearCookieHeader(req, proxy.ds.AllowedCookies()) + proxyutil.ClearCookieHeader(req, proxy.ds.AllowedCookies(), []string{proxy.cfg.LoginCookieName}) req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion)) jsonData := make(map[string]interface{}) diff --git a/pkg/api/plugins_test.go b/pkg/api/plugins_test.go index 1965d578340..043b6f2a56e 100644 --- a/pkg/api/plugins_test.go +++ b/pkg/api/plugins_test.go @@ -23,6 +23,7 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" ac "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/pluginsettings" "github.com/grafana/grafana/pkg/services/quota/quotatest" @@ -320,6 +321,12 @@ func TestMakePluginResourceRequest(t *testing.T) { pluginClient: &fakePluginClient{}, } req := httptest.NewRequest(http.MethodGet, "/", nil) + + const customHeader = "X-CUSTOM" + req.Header.Set(customHeader, "val") + ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) + req = req.WithContext(ctx) + resp := httptest.NewRecorder() pCtx := backend.PluginContext{} err := hs.makePluginResourceRequest(resp, req, pCtx) @@ -332,6 +339,7 @@ func TestMakePluginResourceRequest(t *testing.T) { } require.Equal(t, "sandbox", resp.Header().Get("Content-Security-Policy")) + require.Empty(t, req.Header.Get(customHeader)) } func callGetPluginAsset(sc *scenarioContext) { diff --git a/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go index 07753e0c0f5..c9eb1767e01 100644 --- a/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go +++ b/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go @@ -13,6 +13,7 @@ func TestForwardedCookiesMiddleware(t *testing.T) { tcs := []struct { desc string allowedCookies []string + disallowedCookies []string expectedCookieHeader string }{ { @@ -30,6 +31,12 @@ func TestForwardedCookiesMiddleware(t *testing.T) { allowedCookies: []string{"c1", "c3"}, expectedCookieHeader: "c1=1; c3=3", }, + { + desc: "When provided with allowed and not allowed cookies should populate Cookie header", + allowedCookies: []string{"c1", "c3"}, + disallowedCookies: []string{"c1"}, + expectedCookieHeader: "c3=3", + }, } for _, tc := range tcs { @@ -41,7 +48,7 @@ func TestForwardedCookiesMiddleware(t *testing.T) { {Name: "c2", Value: "2"}, {Name: "c3", Value: "3"}, } - mw := httpclientprovider.ForwardedCookiesMiddleware(forwarded, tc.allowedCookies) + mw := httpclientprovider.ForwardedCookiesMiddleware(forwarded, tc.allowedCookies, tc.disallowedCookies) opts := httpclient.Options{} rt := mw.CreateMiddleware(opts, finalRoundTripper) require.NotNil(t, rt) diff --git a/pkg/infra/httpclient/httpclientprovider/forwarded_cookies_middleware.go b/pkg/infra/httpclient/httpclientprovider/forwarded_cookies_middleware.go index 7d0ef6f2ef5..90c8314bcad 100644 --- a/pkg/infra/httpclient/httpclientprovider/forwarded_cookies_middleware.go +++ b/pkg/infra/httpclient/httpclientprovider/forwarded_cookies_middleware.go @@ -11,13 +11,13 @@ const ForwardedCookiesMiddlewareName = "forwarded-cookies" // ForwardedCookiesMiddleware middleware that sets Cookie header on the // outgoing request, if forwarded cookies configured/provided. -func ForwardedCookiesMiddleware(forwardedCookies []*http.Cookie, allowedCookies []string) httpclient.Middleware { +func ForwardedCookiesMiddleware(forwardedCookies []*http.Cookie, allowedCookies []string, disallowedCookies []string) httpclient.Middleware { return httpclient.NamedMiddlewareFunc(ForwardedCookiesMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { for _, cookie := range forwardedCookies { req.AddCookie(cookie) } - proxyutil.ClearCookieHeader(req, allowedCookies) + proxyutil.ClearCookieHeader(req, allowedCookies, disallowedCookies) return next.RoundTrip(req) }) }) diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go index c792ea15f23..215a819f79e 100644 --- a/pkg/middleware/middleware_basic_auth_test.go +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -38,6 +38,9 @@ func TestMiddlewareBasicAuth(t *testing.T) { assert.True(t, sc.context.IsSignedIn) assert.Equal(t, orgID, sc.context.OrgID) assert.Equal(t, org.RoleEditor, sc.context.OrgRole) + list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context()) + require.NotNil(t, list) + require.EqualValues(t, []string{"Authorization"}, list.Items) }, configure) middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) { @@ -71,6 +74,9 @@ func TestMiddlewareBasicAuth(t *testing.T) { assert.True(t, sc.context.IsSignedIn) assert.Equal(t, id, sc.context.UserID) + list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context()) + require.NotNil(t, list) + require.EqualValues(t, []string{"Authorization"}, list.Items) }, configure) middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) { diff --git a/pkg/middleware/middleware_jwt_auth_test.go b/pkg/middleware/middleware_jwt_auth_test.go index 7e791b9b720..abb35ae7469 100644 --- a/pkg/middleware/middleware_jwt_auth_test.go +++ b/pkg/middleware/middleware_jwt_auth_test.go @@ -7,6 +7,7 @@ import ( "github.com/grafana/grafana/pkg/services/org" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/contexthandler" @@ -75,6 +76,9 @@ func TestMiddlewareJWTAuth(t *testing.T) { assert.Equal(t, orgID, sc.context.OrgID) assert.Equal(t, id, sc.context.UserID) assert.Equal(t, myUsername, sc.context.Login) + list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context()) + require.NotNil(t, list) + require.EqualValues(t, []string{sc.cfg.JWTAuthHeaderName}, list.Items) }, configure, configureUsernameClaim) middlewareScenario(t, "Valid token with bearer in authorization header", func(t *testing.T, sc *scenarioContext) { diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index d264a65f80a..855c0ee514d 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -422,6 +422,11 @@ func TestMiddlewareContext(t *testing.T) { assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserID) assert.Equal(t, orgID, sc.context.OrgID) + list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context()) + require.NotNil(t, list) + require.Contains(t, list.Items, sc.cfg.AuthProxyHeaderName) + require.Contains(t, list.Items, "X-WEBAUTH-GROUPS") + require.Contains(t, list.Items, "X-WEBAUTH-ROLE") }, func(cfg *setting.Cfg) { configure(cfg) cfg.LDAPEnabled = false diff --git a/pkg/plugins/manager/loader/loader_test.go b/pkg/plugins/manager/loader/loader_test.go index 9247d23d479..1dc7a03c902 100644 --- a/pkg/plugins/manager/loader/loader_test.go +++ b/pkg/plugins/manager/loader/loader_test.go @@ -298,7 +298,7 @@ func TestLoader_Load(t *testing.T) { }, }, { - name: "Load an unsigned plugin with modified signature (production)", + name: "Load a plugin with v1 manifest should return signatureInvalid", class: plugins.External, cfg: &config.Cfg{}, pluginPaths: []string{"../testdata/lacking-files"}, @@ -306,12 +306,12 @@ func TestLoader_Load(t *testing.T) { pluginErrors: map[string]*plugins.Error{ "test-datasource": { PluginID: "test-datasource", - ErrorCode: "signatureModified", + ErrorCode: "signatureInvalid", }, }, }, { - name: "Load an unsigned plugin with modified signature using PluginsAllowUnsigned config (production) still includes a signing error", + name: "Load a plugin with v1 manifest using PluginsAllowUnsigned config (production) should return signatureInvali", class: plugins.External, cfg: &config.Cfg{ PluginsAllowUnsigned: []string{"test-datasource"}, @@ -321,7 +321,7 @@ func TestLoader_Load(t *testing.T) { pluginErrors: map[string]*plugins.Error{ "test-datasource": { PluginID: "test-datasource", - ErrorCode: "signatureModified", + ErrorCode: "signatureInvalid", }, }, }, diff --git a/pkg/plugins/manager/signature/manifest.go b/pkg/plugins/manager/signature/manifest.go index a1086daeddc..858ccf57910 100644 --- a/pkg/plugins/manager/signature/manifest.go +++ b/pkg/plugins/manager/signature/manifest.go @@ -132,6 +132,12 @@ func Calculate(mlog log.Logger, plugin *plugins.Plugin) (plugins.Signature, erro }, nil } + if !manifest.isV2() { + return plugins.Signature{ + Status: plugins.SignatureInvalid, + }, nil + } + // Make sure the versions all match if manifest.Plugin != plugin.ID || manifest.Version != plugin.Info.Version { return plugins.Signature{ @@ -167,21 +173,19 @@ func Calculate(mlog log.Logger, plugin *plugins.Plugin) (plugins.Signature, erro manifestFiles[p] = struct{}{} } - if manifest.isV2() { - // Track files missing from the manifest - var unsignedFiles []string - for _, f := range pluginFiles { - if _, exists := manifestFiles[f]; !exists { - unsignedFiles = append(unsignedFiles, f) - } + // Track files missing from the manifest + var unsignedFiles []string + for _, f := range pluginFiles { + if _, exists := manifestFiles[f]; !exists { + unsignedFiles = append(unsignedFiles, f) } + } - if len(unsignedFiles) > 0 { - mlog.Warn("The following files were not included in the signature", "plugin", plugin.ID, "files", unsignedFiles) - return plugins.Signature{ - Status: plugins.SignatureModified, - }, nil - } + if len(unsignedFiles) > 0 { + mlog.Warn("The following files were not included in the signature", "plugin", plugin.ID, "files", unsignedFiles) + return plugins.Signature{ + Status: plugins.SignatureModified, + }, nil } mlog.Debug("Plugin signature valid", "id", plugin.ID) diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index 898293a7d90..b671d816d68 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -142,6 +142,9 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64) return true } + newCtx := WithAuthHTTPHeader(ctx.Req.Context(), h.Cfg.JWTAuthHeaderName) + *ctx.Req = *ctx.Req.WithContext(newCtx) + ctx.SignedInUser = queryResult ctx.IsSignedIn = true diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index f343bbfdea8..49815079b72 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -258,6 +258,9 @@ func (h *ContextHandler) initContextWithAPIKey(reqContext *models.ReqContext) bo _, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAPIKey") defer span.End() + ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization") + *reqContext.Req = *reqContext.Req.WithContext(ctx) + var ( apikey *apikey.APIKey errKey error @@ -347,7 +350,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext, return false } - ctx, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithBasicAuth") + _, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithBasicAuth") defer span.End() username, password, err := util.DecodeBasicAuthHeader(header) @@ -356,12 +359,15 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext, return true } + ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization") + *reqContext.Req = *reqContext.Req.WithContext(ctx) + authQuery := models.LoginUserQuery{ Username: username, Password: password, Cfg: h.Cfg, } - if err := h.authenticator.AuthenticateUser(reqContext.Req.Context(), &authQuery); err != nil { + if err := h.authenticator.AuthenticateUser(ctx, &authQuery); err != nil { reqContext.Logger.Debug( "Failed to authorize the user", "username", username, @@ -610,6 +616,15 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, logger.Debug("Successfully got user info", "userID", user.UserID, "username", user.Login) + ctx := WithAuthHTTPHeader(reqContext.Req.Context(), h.Cfg.AuthProxyHeaderName) + for _, header := range h.Cfg.AuthProxyHeaders { + if header != "" { + ctx = WithAuthHTTPHeader(ctx, header) + } + } + + *reqContext.Req = *reqContext.Req.WithContext(ctx) + // Add user info to context reqContext.SignedInUser = user reqContext.IsSignedIn = true @@ -629,3 +644,38 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, return true } + +type authHTTPHeaderListContextKey struct{} + +var authHTTPHeaderListKey = authHTTPHeaderListContextKey{} + +// AuthHTTPHeaderList used to record HTTP headers that being when verifying authentication +// of an incoming HTTP request. +type AuthHTTPHeaderList struct { + Items []string +} + +// WithAuthHTTPHeader returns a copy of parent in which the named HTTP header will be included +// and later retrievable by AuthHTTPHeaderListFromContext. +func WithAuthHTTPHeader(parent context.Context, name string) context.Context { + list := AuthHTTPHeaderListFromContext(parent) + + if list == nil { + list = &AuthHTTPHeaderList{ + Items: []string{}, + } + } + + list.Items = append(list.Items, name) + + return context.WithValue(parent, authHTTPHeaderListKey, list) +} + +// AuthHTTPHeaderListFromContext returns the AuthHTTPHeaderList in a context.Context, if any, +// and will include any HTTP headers used when verifying authentication of an incoming HTTP request. +func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList { + if list, ok := c.Value(authHTTPHeaderListKey).(*AuthHTTPHeaderList); ok { + return list + } + return nil +} diff --git a/pkg/services/query/query.go b/pkg/services/query/query.go index 2cffca07861..de256f72c0a 100644 --- a/pkg/services/query/query.go +++ b/pkg/services/query/query.go @@ -171,7 +171,7 @@ func (s *Service) handleQueryData(ctx context.Context, user *user.SignedInUser, middlewares := []httpclient.Middleware{} if parsedReq.httpRequest != nil { middlewares = append(middlewares, - httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies()), + httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), []string{s.cfg.LoginCookieName}), ) } @@ -188,7 +188,7 @@ func (s *Service) handleQueryData(ctx context.Context, user *user.SignedInUser, } if parsedReq.httpRequest != nil { - proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies()) + proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), []string{s.cfg.LoginCookieName}) if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" { req.Headers["Cookie"] = cookieStr } diff --git a/pkg/services/query/query_test.go b/pkg/services/query/query_test.go index 33c8fc1f6bf..3e4ab7427a9 100644 --- a/pkg/services/query/query_test.go +++ b/pkg/services/query/query_test.go @@ -26,6 +26,7 @@ import ( secretsmng "github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" ) func TestQueryDataMultipleSources(t *testing.T) { @@ -197,7 +198,7 @@ func setup(t *testing.T) *testContext { dataSourceCache: dc, oauthTokenService: tc, pluginRequestValidator: rv, - queryService: query.ProvideService(nil, dc, exprService, rv, ds, pc, tc), + queryService: query.ProvideService(setting.NewCfg(), dc, exprService, rv, ds, pc, tc), } } diff --git a/pkg/util/proxyutil/proxyutil.go b/pkg/util/proxyutil/proxyutil.go index 3db22a1426e..ee56120cc6f 100644 --- a/pkg/util/proxyutil/proxyutil.go +++ b/pkg/util/proxyutil/proxyutil.go @@ -3,6 +3,7 @@ package proxyutil import ( "net" "net/http" + "sort" ) // PrepareProxyRequest prepares a request for being proxied. @@ -26,19 +27,31 @@ func PrepareProxyRequest(req *http.Request) { } } -// ClearCookieHeader clear cookie header, except for cookies specified to be kept. -func ClearCookieHeader(req *http.Request, keepCookiesNames []string) { - var keepCookies []*http.Cookie +// ClearCookieHeader clear cookie header, except for cookies specified to be kept (keepCookiesNames) if not in skipCookiesNames. +func ClearCookieHeader(req *http.Request, keepCookiesNames []string, skipCookiesNames []string) { + keepCookies := map[string]*http.Cookie{} for _, c := range req.Cookies() { for _, v := range keepCookiesNames { if c.Name == v { - keepCookies = append(keepCookies, c) + keepCookies[c.Name] = c } } } + for _, v := range skipCookiesNames { + delete(keepCookies, v) + } + req.Header.Del("Cookie") - for _, c := range keepCookies { + + sortedCookies := []string{} + for name := range keepCookies { + sortedCookies = append(sortedCookies, name) + } + sort.Strings(sortedCookies) + + for _, name := range sortedCookies { + c := keepCookies[name] req.AddCookie(c) } } diff --git a/pkg/util/proxyutil/proxyutil_test.go b/pkg/util/proxyutil/proxyutil_test.go index 5ff61ec1d29..03d816bbcd8 100644 --- a/pkg/util/proxyutil/proxyutil_test.go +++ b/pkg/util/proxyutil/proxyutil_test.go @@ -49,7 +49,7 @@ func TestClearCookieHeader(t *testing.T) { require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "cookie"}) - ClearCookieHeader(req, nil) + ClearCookieHeader(req, nil, nil) require.NotContains(t, req.Header, "Cookie") }) @@ -60,8 +60,20 @@ func TestClearCookieHeader(t *testing.T) { req.AddCookie(&http.Cookie{Name: "cookie2"}) req.AddCookie(&http.Cookie{Name: "cookie3"}) - ClearCookieHeader(req, []string{"cookie1", "cookie3"}) + ClearCookieHeader(req, []string{"cookie1", "cookie3"}, nil) require.Contains(t, req.Header, "Cookie") require.Equal(t, "cookie1=; cookie3=", req.Header.Get("Cookie")) }) + + t.Run("Clear cookie header with cookies to keep and skip should clear Cookie header and keep cookies", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{Name: "cookie1"}) + req.AddCookie(&http.Cookie{Name: "cookie2"}) + req.AddCookie(&http.Cookie{Name: "cookie3"}) + + ClearCookieHeader(req, []string{"cookie1", "cookie3"}, []string{"cookie3"}) + require.Contains(t, req.Header, "Cookie") + require.Equal(t, "cookie1=", req.Header.Get("Cookie")) + }) } diff --git a/pkg/util/proxyutil/reverse_proxy.go b/pkg/util/proxyutil/reverse_proxy.go index 58969dc29ee..bff95092298 100644 --- a/pkg/util/proxyutil/reverse_proxy.go +++ b/pkg/util/proxyutil/reverse_proxy.go @@ -10,6 +10,7 @@ import ( "time" glog "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/contexthandler" ) // StatusClientClosedRequest A non-standard status code introduced by nginx @@ -66,6 +67,13 @@ func NewReverseProxy(logger glog.Logger, director func(*http.Request), opts ...R // wrapDirector wraps a director and adds additional functionality. func wrapDirector(d func(*http.Request)) func(req *http.Request) { return func(req *http.Request) { + list := contexthandler.AuthHTTPHeaderListFromContext(req.Context()) + if list != nil { + for _, name := range list.Items { + req.Header.Del(name) + } + } + d(req) PrepareProxyRequest(req) diff --git a/pkg/util/proxyutil/reverse_proxy_test.go b/pkg/util/proxyutil/reverse_proxy_test.go index 19c7db9144d..b5dfee0c3e9 100644 --- a/pkg/util/proxyutil/reverse_proxy_test.go +++ b/pkg/util/proxyutil/reverse_proxy_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/stretchr/testify/require" ) @@ -30,6 +31,11 @@ func TestReverseProxy(t *testing.T) { req.Header.Set("Referer", "https://test.com/api") req.RemoteAddr = "10.0.0.1" + const customHeader = "X-CUSTOM" + req.Header.Set(customHeader, "val") + ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) + req = req.WithContext(ctx) + rp := NewReverseProxy(log.New("test"), func(req *http.Request) { req.Header.Set("X-KEY", "value") }) @@ -49,6 +55,7 @@ func TestReverseProxy(t *testing.T) { require.Empty(t, resp.Cookies()) require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy")) require.NoError(t, resp.Body.Close()) + require.Empty(t, actualReq.Header.Get(customHeader)) }) t.Run("When proxying a request using WithModifyResponse should call it before default ModifyResponse func", func(t *testing.T) {