diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f6f5271bdba..bd8bab5fa4a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -103,9 +103,11 @@ go.sum @grafana/backend-platform # Plugins /pkg/api/pluginproxy @grafana/plugins-platform-backend +/pkg/infra/httpclient @grafana/plugins-platform-backend /pkg/plugins @grafana/plugins-platform-backend /pkg/services/datasourceproxy @grafana/plugins-platform-backend /pkg/services/datasources @grafana/plugins-platform-backend +/pkg/services/pluginsintegration @grafana/plugins-platform-backend /pkg/plugins/pfs @grafana/plugins-platform-backend @grafana/grafana-as-code # Dashboard previews / crawler (behind feature flag) @@ -229,6 +231,7 @@ lerna.json @grafana/frontend-ops # Grafana Partnerships Team /pkg/infra/httpclient/httpclientprovider/sigv4_middleware.go @grafana/grafana-partnerships-team +/pkg/infra/httpclient/httpclientprovider/sigv4_middleware_test.go @grafana/grafana-partnerships-team # Kind system and code generation embed.go @grafana/grafana-as-code diff --git a/pkg/api/datasources.go b/pkg/api/datasources.go index 45458f36a26..7e2ff3f32ec 100644 --- a/pkg/api/datasources.go +++ b/pkg/api/datasources.go @@ -24,7 +24,6 @@ import ( "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" - "github.com/grafana/grafana/pkg/util/proxyutil" "github.com/grafana/grafana/pkg/web" ) @@ -823,21 +822,6 @@ func (hs *HTTPServer) checkDatasourceHealth(c *models.ReqContext, ds *datasource return response.Error(http.StatusForbidden, "Access denied", err) } - if hs.DataProxy.OAuthTokenService.IsOAuthPassThruEnabled(ds) { - if token := hs.DataProxy.OAuthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser); token != nil { - req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) - idToken, ok := token.Extra("id_token").(string) - if ok && idToken != "" { - req.Headers["X-ID-Token"] = idToken - } - } - } - - proxyutil.ClearCookieHeader(c.Req, ds.AllowedCookies(), []string{hs.Cfg.LoginCookieName}) - if cookieStr := c.Req.Header.Get("Cookie"); cookieStr != "" { - req.Headers["Cookie"] = cookieStr - } - resp, err := hs.pluginClient.CheckHealth(c.Req.Context(), req) if err != nil { return translatePluginRequestErrorToAPIError(err) diff --git a/pkg/api/dtos/models.go b/pkg/api/dtos/models.go index 9fa6eb79667..2b65fa0e4b5 100644 --- a/pkg/api/dtos/models.go +++ b/pkg/api/dtos/models.go @@ -3,7 +3,6 @@ package dtos import ( "crypto/md5" "fmt" - "net/http" "regexp" "strings" @@ -73,8 +72,6 @@ type MetricRequest struct { Debug bool `json:"debug"` PublicDashboardAccessToken string `json:"publicDashboardAccessToken"` - - HTTPRequest *http.Request `json:"-"` } func (mr *MetricRequest) GetUniqueDatasourceTypes() []string { @@ -98,11 +95,10 @@ func (mr *MetricRequest) GetUniqueDatasourceTypes() []string { func (mr *MetricRequest) CloneWithQueries(queries []*simplejson.Json) MetricRequest { return MetricRequest{ - From: mr.From, - To: mr.To, - Queries: queries, - Debug: mr.Debug, - HTTPRequest: mr.HTTPRequest, + From: mr.From, + To: mr.To, + Queries: queries, + Debug: mr.Debug, } } diff --git a/pkg/api/metrics.go b/pkg/api/metrics.go index 07a59ab1911..6f74a46a3f9 100644 --- a/pkg/api/metrics.go +++ b/pkg/api/metrics.go @@ -52,8 +52,6 @@ func (hs *HTTPServer) QueryMetricsV2(c *models.ReqContext) response.Response { return response.Error(http.StatusBadRequest, "bad request data", err) } - reqDTO.HTTPRequest = c.Req - resp, err := hs.queryDataService.QueryData(c.Req.Context(), c.SignedInUser, c.SkipCache, reqDTO) if err != nil { return hs.handleQueryMetricsError(err) diff --git a/pkg/api/metrics_test.go b/pkg/api/metrics_test.go index 3992a5cac1d..c21e177bf95 100644 --- a/pkg/api/metrics_test.go +++ b/pkg/api/metrics_test.go @@ -13,15 +13,12 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" - "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/backendplugin" "github.com/grafana/grafana/pkg/plugins/config" pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client" "github.com/grafana/grafana/pkg/plugins/manager/registry" - "github.com/grafana/grafana/pkg/services/datasources" fakeDatasources "github.com/grafana/grafana/pkg/services/datasources/fakes" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" @@ -46,31 +43,6 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request) return rv.err } -type fakeOAuthTokenService struct { - passThruEnabled bool - token *oauth2.Token -} - -func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { - return ts.token -} - -func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool { - return ts.passThruEnabled -} - -func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { - return nil, false, nil -} - -func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { - return nil -} - -func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { - return nil -} - // `/ds/query` endpoint test func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) { qds := query.ProvideService( @@ -89,7 +61,6 @@ func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) { return &backend.QueryDataResponse{Responses: resp}, nil }, }, - &fakeOAuthTokenService{}, ) serverFeatureEnabled := SetupAPITestServer(t, func(hs *HTTPServer) { hs.queryDataService = qds @@ -138,7 +109,6 @@ func TestAPIEndpoint_Metrics_PluginDecryptionFailure(t *testing.T) { return &backend.QueryDataResponse{Responses: resp}, nil }, }, - &fakeOAuthTokenService{}, ) httpServer := SetupAPITestServer(t, func(hs *HTTPServer) { hs.queryDataService = qds @@ -292,7 +262,6 @@ func TestDataSourceQueryError(t *testing.T) { &fakePluginRequestValidator{}, &fakeDatasources.FakeDataSourceService{}, pluginClient.ProvideService(r, &config.Cfg{}), - &fakeOAuthTokenService{}, ) hs.QuotaService = quotatest.New(false, nil) }) diff --git a/pkg/api/plugin_resource.go b/pkg/api/plugin_resource.go index c73941a087b..970c4d6a4bc 100644 --- a/pkg/api/plugin_resource.go +++ b/pkg/api/plugin_resource.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -15,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins/backendplugin" - "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/util/proxyutil" "github.com/grafana/grafana/pkg/web" @@ -78,17 +76,6 @@ func (hs *HTTPServer) callPluginResourceWithDataSource(c *models.ReqContext, plu return } - if hs.DataProxy.OAuthTokenService.IsOAuthPassThruEnabled(ds) { - if token := hs.DataProxy.OAuthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser); token != nil { - req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken)) - - idToken, ok := token.Extra("id_token").(string) - if ok && idToken != "" { - req.Header.Add("X-ID-Token", idToken) - } - } - } - if err = hs.makePluginResourceRequest(c.Resp, req, pCtx); err != nil { handleCallResourceError(err, c) } @@ -110,24 +97,6 @@ func (hs *HTTPServer) pluginResourceRequest(c *models.ReqContext) (*http.Request } func (hs *HTTPServer) makePluginResourceRequest(w http.ResponseWriter, req *http.Request, pCtx backend.PluginContext) error { - keepCookieModel := struct { - KeepCookies []string `json:"keepCookies"` - }{} - if dis := pCtx.DataSourceInstanceSettings; dis != nil { - err := json.Unmarshal(dis.JSONData, &keepCookieModel) - if err != nil { - hs.log.Warn("failed to unpack JSONData in datasource instance settings", "err", err) - } - } - - list := contexthandler.AuthHTTPHeaderListFromContext(req.Context()) - if list != nil { - for _, name := range list.Items { - req.Header.Del(name) - } - } - - proxyutil.ClearCookieHeader(req, keepCookieModel.KeepCookies, []string{hs.Cfg.LoginCookieName}) proxyutil.PrepareProxyRequest(req) body, err := io.ReadAll(req.Body) diff --git a/pkg/api/plugins_test.go b/pkg/api/plugins_test.go index 5dad8f06d38..e579996bc41 100644 --- a/pkg/api/plugins_test.go +++ b/pkg/api/plugins_test.go @@ -23,7 +23,6 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" ac "github.com/grafana/grafana/pkg/services/accesscontrol" - "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/pluginsettings" "github.com/grafana/grafana/pkg/services/quota/quotatest" @@ -315,11 +314,6 @@ func TestMakePluginResourceRequest(t *testing.T) { } req := httptest.NewRequest(http.MethodGet, "/", nil) - const customHeader = "X-CUSTOM" - req.Header.Set(customHeader, "val") - ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) - req = req.WithContext(ctx) - resp := httptest.NewRecorder() pCtx := backend.PluginContext{} err := hs.makePluginResourceRequest(resp, req, pCtx) @@ -333,7 +327,6 @@ func TestMakePluginResourceRequest(t *testing.T) { require.Equal(t, resp.Header().Get("Content-Type"), "application/json") require.Equal(t, "sandbox", resp.Header().Get("Content-Security-Policy")) - require.Empty(t, req.Header.Get(customHeader)) } func TestMakePluginResourceRequestSetCookieNotPresent(t *testing.T) { diff --git a/pkg/cmd/grafana-cli/runner/wire.go b/pkg/cmd/grafana-cli/runner/wire.go index 8a40ca839a9..25d27ff9e8d 100644 --- a/pkg/cmd/grafana-cli/runner/wire.go +++ b/pkg/cmd/grafana-cli/runner/wire.go @@ -34,18 +34,7 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/csrf" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/plugins" - "github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin" - pluginsCfg "github.com/grafana/grafana/pkg/plugins/config" - "github.com/grafana/grafana/pkg/plugins/manager" - "github.com/grafana/grafana/pkg/plugins/manager/client" pluginDashboards "github.com/grafana/grafana/pkg/plugins/manager/dashboards" - "github.com/grafana/grafana/pkg/plugins/manager/loader" - processManager "github.com/grafana/grafana/pkg/plugins/manager/process" - "github.com/grafana/grafana/pkg/plugins/manager/registry" - managerStore "github.com/grafana/grafana/pkg/plugins/manager/store" - "github.com/grafana/grafana/pkg/plugins/plugincontext" - "github.com/grafana/grafana/pkg/plugins/repo" "github.com/grafana/grafana/pkg/registry/corekind" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" @@ -94,6 +83,7 @@ import ( plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service" "github.com/grafana/grafana/pkg/services/pluginsettings" pluginSettings "github.com/grafana/grafana/pkg/services/pluginsettings/service" + "github.com/grafana/grafana/pkg/services/pluginsintegration" "github.com/grafana/grafana/pkg/services/preference/prefimpl" "github.com/grafana/grafana/pkg/services/publicdashboards" publicdashboardsApi "github.com/grafana/grafana/pkg/services/publicdashboards/api" @@ -181,28 +171,9 @@ var wireSet = wire.NewSet( updatechecker.ProvideGrafanaService, updatechecker.ProvidePluginsService, uss.ProvideService, - pluginsCfg.ProvideConfig, - registry.ProvideService, - wire.Bind(new(registry.Service), new(*registry.InMemory)), - repo.ProvideService, - wire.Bind(new(repo.Service), new(*repo.Manager)), - manager.ProvideInstaller, - wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)), - client.ProvideService, - wire.Bind(new(plugins.Client), new(*client.Service)), - managerStore.ProvideService, - wire.Bind(new(plugins.Store), new(*managerStore.Service)), - wire.Bind(new(plugins.RendererManager), new(*managerStore.Service)), - wire.Bind(new(plugins.SecretsPluginManager), new(*managerStore.Service)), - wire.Bind(new(plugins.StaticRouteResolver), new(*managerStore.Service)), + pluginsintegration.WireSet, pluginDashboards.ProvideFileStoreManager, wire.Bind(new(pluginDashboards.FileStore), new(*pluginDashboards.FileStoreManager)), - processManager.ProvideService, - wire.Bind(new(processManager.Service), new(*processManager.Manager)), - coreplugin.ProvideCoreRegistry, - loader.ProvideService, - wire.Bind(new(loader.Service), new(*loader.Loader)), - wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)), cloudwatch.ProvideService, cloudmonitoring.ProvideService, azuremonitor.ProvideService, @@ -236,7 +207,6 @@ var wireSet = wire.NewSet( export.ProvideService, live.ProvideService, pushhttp.ProvideService, - plugincontext.ProvideService, contexthandler.ProvideService, jwt.ProvideService, wire.Bind(new(models.JWTService), new(*jwt.AuthService)), diff --git a/pkg/cmd/grafana-cli/runner/wireexts_oss.go b/pkg/cmd/grafana-cli/runner/wireexts_oss.go index dfc36b3cd26..930c1fcf31f 100644 --- a/pkg/cmd/grafana-cli/runner/wireexts_oss.go +++ b/pkg/cmd/grafana-cli/runner/wireexts_oss.go @@ -7,9 +7,6 @@ import ( "github.com/google/wire" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/plugins" - "github.com/grafana/grafana/pkg/plugins/backendplugin/provider" - "github.com/grafana/grafana/pkg/plugins/manager/signature" "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/server/backgroundsvcs" "github.com/grafana/grafana/pkg/server/usagestatssvcs" @@ -29,6 +26,7 @@ import ( "github.com/grafana/grafana/pkg/services/licensing" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login/authinfoservice" + "github.com/grafana/grafana/pkg/services/pluginsintegration" "github.com/grafana/grafana/pkg/services/provisioning" "github.com/grafana/grafana/pkg/services/searchusers" "github.com/grafana/grafana/pkg/services/searchusers/filters" @@ -71,10 +69,6 @@ var wireExtsSet = wire.NewSet( wire.Bind(new(user.SearchUserFilter), new(*filters.OSSSearchUserFilter)), searchusers.ProvideUsersService, wire.Bind(new(searchusers.Service), new(*searchusers.OSSService)), - signature.ProvideOSSAuthorizer, - wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)), - provider.ProvideService, - wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)), ldap.ProvideGroupsService, wire.Bind(new(ldap.Groups), new(*ldap.OSSGroups)), permissions.ProvideDatasourcePermissionsService, @@ -85,4 +79,5 @@ var wireExtsSet = wire.NewSet( wire.Bind(new(accesscontrol.DatasourcePermissionsService), new(*ossaccesscontrol.DatasourcePermissionsService)), encryptionprovider.ProvideEncryptionProvider, wire.Bind(new(encryption.Provider), new(encryptionprovider.Provider)), + pluginsintegration.WireExtensionSet, ) diff --git a/pkg/expr/graph.go b/pkg/expr/graph.go index deb915e9f81..9e78e5e2f2e 100644 --- a/pkg/expr/graph.go +++ b/pkg/expr/graph.go @@ -63,6 +63,10 @@ func (dp *DataPipeline) execute(c context.Context, now time.Time, s *Service) (m // BuildPipeline builds a graph of the nodes, and returns the nodes in an // executable order. func (s *Service) buildPipeline(req *Request) (DataPipeline, error) { + if req != nil && len(req.Headers) == 0 { + req.Headers = map[string]string{} + } + graph, err := s.buildDependencyGraph(req) if err != nil { return nil, err @@ -144,12 +148,11 @@ func (s *Service) buildGraph(req *Request) (*simple.DirectedGraph, error) { } rn := &rawNode{ - Query: rawQueryProp, - RefID: query.RefID, - TimeRange: query.TimeRange, - QueryType: query.QueryType, - DataSource: query.DataSource, - QueryEnricher: query.QueryEnricher, + Query: rawQueryProp, + RefID: query.RefID, + TimeRange: query.TimeRange, + QueryType: query.QueryType, + DataSource: query.DataSource, } var node Node diff --git a/pkg/expr/nodes.go b/pkg/expr/nodes.go index f679ec38354..8a4fdf69863 100644 --- a/pkg/expr/nodes.go +++ b/pkg/expr/nodes.go @@ -42,12 +42,11 @@ type baseNode struct { } type rawNode struct { - RefID string `json:"refId"` - Query map[string]interface{} - QueryType string - TimeRange TimeRange - DataSource *datasources.DataSource - QueryEnricher QueryDataRequestEnricher + RefID string `json:"refId"` + Query map[string]interface{} + QueryType string + TimeRange TimeRange + DataSource *datasources.DataSource } func (rn *rawNode) GetCommandType() (c CommandType, err error) { @@ -140,9 +139,8 @@ const ( // DSNode is a DPNode that holds a datasource request. type DSNode struct { baseNode - query json.RawMessage - datasource *datasources.DataSource - queryEnricher QueryDataRequestEnricher + query json.RawMessage + datasource *datasources.DataSource orgID int64 queryType string @@ -171,15 +169,14 @@ func (s *Service) buildDSNode(dp *simple.DirectedGraph, rn *rawNode, req *Reques id: dp.NewNode().ID(), refID: rn.RefID, }, - orgID: req.OrgId, - query: json.RawMessage(encodedQuery), - queryType: rn.QueryType, - intervalMS: defaultIntervalMS, - maxDP: defaultMaxDP, - timeRange: rn.TimeRange, - request: *req, - datasource: rn.DataSource, - queryEnricher: rn.QueryEnricher, + orgID: req.OrgId, + query: json.RawMessage(encodedQuery), + queryType: rn.QueryType, + intervalMS: defaultIntervalMS, + maxDP: defaultMaxDP, + timeRange: rn.TimeRange, + request: *req, + datasource: rn.DataSource, } var floatIntervalMS float64 @@ -232,10 +229,6 @@ func (dn *DSNode) Execute(ctx context.Context, now time.Time, _ mathexp.Vars, s Headers: dn.request.Headers, } - if dn.queryEnricher != nil { - ctx = dn.queryEnricher(ctx, req) - } - resp, err := s.dataService.QueryData(ctx, req) if err != nil { return mathexp.Results{}, err diff --git a/pkg/expr/transform.go b/pkg/expr/transform.go index 5d5351b04ef..b757ea7621f 100644 --- a/pkg/expr/transform.go +++ b/pkg/expr/transform.go @@ -38,16 +38,12 @@ type Request struct { User *backend.User } -// QueryDataRequestEnricher function definition for enriching a backend.QueryDataRequest request. -type QueryDataRequestEnricher func(ctx context.Context, req *backend.QueryDataRequest) context.Context - // Query is like plugins.DataSubQuery, but with a a time range, and only the UID // for the data source. Also interval is a time.Duration. type Query struct { RefID string TimeRange TimeRange DataSource *datasources.DataSource `json:"datasource"` - QueryEnricher QueryDataRequestEnricher JSON json.RawMessage Interval time.Duration QueryType string diff --git a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go new file mode 100644 index 00000000000..3bc3bcdc824 --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware.go @@ -0,0 +1,27 @@ +package httpclientprovider + +import ( + "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" +) + +const DeleteHeadersMiddlewareName = "delete-headers" + +// DeleteHeadersMiddleware middleware that delete headers on the outgoing +// request if header names provided. +func DeleteHeadersMiddleware(headerNames ...string) httpclient.Middleware { + return httpclient.NamedMiddlewareFunc(DeleteHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { + if len(headerNames) == 0 { + return next + } + + return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for _, k := range headerNames { + req.Header.Del(k) + } + + return next.RoundTrip(req) + }) + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go new file mode 100644 index 00000000000..01ec6a0b08e --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/delete_headers_middleware_test.go @@ -0,0 +1,66 @@ +package httpclientprovider + +import ( + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/stretchr/testify/require" +) + +func TestDeleteHeadersMiddleware(t *testing.T) { + t.Run("Without headerNames should return next http.RoundTripper", func(t *testing.T) { + ctx := &testContext{} + finalRoundTripper := ctx.createRoundTripper("finalrt") + var headerNames []string + mw := DeleteHeadersMiddleware(headerNames...) + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + require.Len(t, ctx.callChain, 1) + require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain) + }) + + t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) { + ctx := &testContext{} + finalRoundTripper := ctx.createRoundTripper("final") + headerNames := []string{"X-Header-B", "X-Header-C"} + mw := DeleteHeadersMiddleware(headerNames...) + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://", nil) + require.NoError(t, err) + req.Header.Set("X-Header-A", "a") + req.Header.Set("X-Header-B", "b") + req.Header.Set("X-Header-C", "c") + req.Header.Set("X-Header-D", "d") + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + require.Len(t, ctx.callChain, 1) + require.ElementsMatch(t, []string{"final"}, ctx.callChain) + + require.Equal(t, "a", req.Header.Get("X-Header-A")) + require.Empty(t, req.Header.Get("X-Header-B")) + require.Empty(t, req.Header.Get("X-Header-C")) + require.Equal(t, "d", req.Header.Get("X-Header-D")) + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go index c9eb1767e01..711952dea1d 100644 --- a/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go +++ b/pkg/infra/httpclient/httpclientprovider/forwarded_cookie_middleware_test.go @@ -70,3 +70,16 @@ func TestForwardedCookiesMiddleware(t *testing.T) { }) } } + +type testContext struct { + callChain []string + req *http.Request +} + +func (c *testContext) createRoundTripper() http.RoundTripper { + return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + c.callChain = append(c.callChain, "final") + c.req = req + return &http.Response{StatusCode: http.StatusOK}, nil + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware.go b/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware.go deleted file mode 100644 index d5c78e794d7..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware.go +++ /dev/null @@ -1,31 +0,0 @@ -package httpclientprovider - -import ( - "fmt" - "net/http" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "golang.org/x/oauth2" -) - -const ForwardedOAuthIdentityMiddlewareName = "forwarded-oauth-identity" - -// ForwardedOAuthIdentityMiddleware middleware that sets Authorization/X-ID-Token -// headers on the outgoing request if an OAuth Token is provided -func ForwardedOAuthIdentityMiddleware(token *oauth2.Token) httpclient.Middleware { - return httpclient.NamedMiddlewareFunc(ForwardedOAuthIdentityMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { - if token == nil { - return next - } - return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken)) - - idToken, ok := token.Extra("id_token").(string) - if ok && idToken != "" { - req.Header.Set("X-ID-Token", idToken) - } - - return next.RoundTrip(req) - }) - }) -} diff --git a/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware_test.go deleted file mode 100644 index 1cd65d1bf09..00000000000 --- a/pkg/infra/httpclient/httpclientprovider/forwarded_oauth_identity_middleware_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package httpclientprovider_test - -import ( - "net/http" - "testing" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func TestForwardedOAuthIdentityMiddleware(t *testing.T) { - at := &oauth2.Token{ - AccessToken: "access-token", - } - tcs := []struct { - desc string - token *oauth2.Token - expectedAuthorizationHeader string - expectedIDTokenHeader string - }{ - { - desc: "With nil token should not populate Cookie headers", - token: nil, - expectedAuthorizationHeader: "", - expectedIDTokenHeader: "", - }, - { - desc: "With access token set should populate Authorization header", - token: at, - expectedAuthorizationHeader: "Bearer access-token", - expectedIDTokenHeader: "", - }, - { - desc: "With Authorization and X-ID-Token header set should populate Authorization and X-Id-Token header", - token: at.WithExtra(map[string]interface{}{"id_token": "id-token"}), - expectedAuthorizationHeader: "Bearer access-token", - expectedIDTokenHeader: "id-token", - }, - } - - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - ctx := &testContext{} - finalRoundTripper := ctx.createRoundTripper() - mw := httpclientprovider.ForwardedOAuthIdentityMiddleware(tc.token) - opts := httpclient.Options{} - rt := mw.CreateMiddleware(opts, finalRoundTripper) - require.NotNil(t, rt) - middlewareName, ok := mw.(httpclient.MiddlewareName) - require.True(t, ok) - require.Equal(t, "forwarded-oauth-identity", middlewareName.MiddlewareName()) - - req, err := http.NewRequest(http.MethodGet, "http://", nil) - require.NoError(t, err) - res, err := rt.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, res) - if res.Body != nil { - require.NoError(t, res.Body.Close()) - } - require.Len(t, ctx.callChain, 1) - require.ElementsMatch(t, []string{"final"}, ctx.callChain) - require.Equal(t, tc.expectedAuthorizationHeader, ctx.req.Header.Get("Authorization")) - require.Equal(t, tc.expectedIDTokenHeader, ctx.req.Header.Get("X-ID-Token")) - }) - } -} - -type testContext struct { - callChain []string - req *http.Request -} - -func (c *testContext) createRoundTripper() http.RoundTripper { - return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - c.callChain = append(c.callChain, "final") - c.req = req - return &http.Response{StatusCode: http.StatusOK}, nil - }) -} diff --git a/pkg/infra/httpclient/httpclientprovider/http_client_provider.go b/pkg/infra/httpclient/httpclientprovider/http_client_provider.go index 054f045471d..da9eb67f966 100644 --- a/pkg/infra/httpclient/httpclientprovider/http_client_provider.go +++ b/pkg/infra/httpclient/httpclientprovider/http_client_provider.go @@ -25,10 +25,10 @@ func New(cfg *setting.Cfg, validator models.PluginRequestValidator, tracer traci middlewares := []sdkhttpclient.Middleware{ TracingMiddleware(logger, tracer), DataSourceMetricsMiddleware(), + sdkhttpclient.ContextualMiddleware(), SetUserAgentMiddleware(userAgent), sdkhttpclient.BasicAuthenticationMiddleware(), sdkhttpclient.CustomHeadersMiddleware(), - sdkhttpclient.ContextualMiddleware(), ResponseLimitMiddleware(cfg.ResponseLimit), RedirectLimitMiddleware(validator), } diff --git a/pkg/infra/httpclient/httpclientprovider/http_client_provider_test.go b/pkg/infra/httpclient/httpclientprovider/http_client_provider_test.go index 5f4955f7738..208c51fff21 100644 --- a/pkg/infra/httpclient/httpclientprovider/http_client_provider_test.go +++ b/pkg/infra/httpclient/httpclientprovider/http_client_provider_test.go @@ -29,10 +29,10 @@ func TestHTTPClientProvider(t *testing.T) { require.Len(t, o.Middlewares, 8) require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName()) }) @@ -53,10 +53,10 @@ func TestHTTPClientProvider(t *testing.T) { require.Len(t, o.Middlewares, 9) require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, SigV4MiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName()) }) @@ -78,10 +78,10 @@ func TestHTTPClientProvider(t *testing.T) { require.Len(t, o.Middlewares, 9) require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, HostRedirectValidationMiddlewareName, o.Middlewares[7].(sdkhttpclient.MiddlewareName).MiddlewareName()) require.Equal(t, HTTPLoggerMiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName()) diff --git a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go new file mode 100644 index 00000000000..787cada75af --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware.go @@ -0,0 +1,31 @@ +package httpclientprovider + +import ( + "net/http" + "net/textproto" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" +) + +const SetHeadersMiddlewareName = "set-headers" + +// SetHeadersMiddleware middleware that sets headers on the outgoing +// request if headers provided. +// If the request already contains any of the headers provided, they +// will be overwritten. +func SetHeadersMiddleware(headers http.Header) httpclient.Middleware { + return httpclient.NamedMiddlewareFunc(SetHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { + if len(headers) == 0 { + return next + } + + return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for k, v := range headers { + canonicalKey := textproto.CanonicalMIMEHeaderKey(k) + req.Header[canonicalKey] = v + } + + return next.RoundTrip(req) + }) + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go new file mode 100644 index 00000000000..0970d23cbcc --- /dev/null +++ b/pkg/infra/httpclient/httpclientprovider/set_headers_middleware_test.go @@ -0,0 +1,66 @@ +package httpclientprovider + +import ( + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/stretchr/testify/require" +) + +func TestSetHeadersMiddleware(t *testing.T) { + t.Run("Without headers set should return next http.RoundTripper", func(t *testing.T) { + ctx := &testContext{} + finalRoundTripper := ctx.createRoundTripper("finalrt") + var headers http.Header + mw := SetHeadersMiddleware(headers) + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + require.Len(t, ctx.callChain, 1) + require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain) + }) + + t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) { + ctx := &testContext{} + finalRoundTripper := ctx.createRoundTripper("final") + headers := http.Header{ + "X-Header-A": []string{"value a"}, + "X-Header-B": []string{"value b"}, + "X-HEader-C": []string{"value c"}, + } + mw := SetHeadersMiddleware(headers) + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://", nil) + require.NoError(t, err) + req.Header.Set("X-Header-B", "d") + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + require.Len(t, ctx.callChain, 1) + require.ElementsMatch(t, []string{"final"}, ctx.callChain) + + require.Equal(t, "value a", req.Header.Get("X-Header-A")) + require.Equal(t, "value b", req.Header.Get("X-Header-B")) + require.Equal(t, "value c", req.Header.Get("X-Header-C")) + }) +} diff --git a/pkg/infra/httpclient/httpclientprovider/user_agent_middleware_test.go b/pkg/infra/httpclient/httpclientprovider/user_agent_middleware_test.go index 9aee212c06b..f4af09c800d 100644 --- a/pkg/infra/httpclient/httpclientprovider/user_agent_middleware_test.go +++ b/pkg/infra/httpclient/httpclientprovider/user_agent_middleware_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestCustomHeadersMiddleware(t *testing.T) { +func TestSetUserAgentMiddleware(t *testing.T) { t.Run("Without user agent set should return next http.RoundTripper", func(t *testing.T) { ctx := &testContext{} finalRoundTripper := ctx.createRoundTripper("finalrt") diff --git a/pkg/plugins/ifaces.go b/pkg/plugins/ifaces.go index da83817d3ba..3f97558d40c 100644 --- a/pkg/plugins/ifaces.go +++ b/pkg/plugins/ifaces.go @@ -79,3 +79,20 @@ type PluginLoaderAuthorizer interface { type RoleRegistry interface { DeclarePluginRoles(ctx context.Context, ID, name string, registrations []RoleRegistration) error } + +// ClientMiddleware is an interface representing the ability to create a middleware +// that implements the Client interface. +type ClientMiddleware interface { + // CreateClientMiddleware creates a new client middleware. + CreateClientMiddleware(next Client) Client +} + +// The ClientMiddlewareFunc type is an adapter to allow the use of ordinary +// functions as ClientMiddleware's. If f is a function with the appropriate +// signature, ClientMiddlewareFunc(f) is a ClientMiddleware that calls f. +type ClientMiddlewareFunc func(next Client) Client + +// CreateClientMiddleware implements the ClientMiddleware interface. +func (fn ClientMiddlewareFunc) CreateClientMiddleware(next Client) Client { + return fn(next) +} diff --git a/pkg/plugins/manager/client/client.go b/pkg/plugins/manager/client/client.go index fe7d67ba497..0adb6e0dcb4 100644 --- a/pkg/plugins/manager/client/client.go +++ b/pkg/plugins/manager/client/client.go @@ -29,6 +29,10 @@ func ProvideService(pluginRegistry registry.Service, cfg *config.Cfg) *Service { } func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + plugin, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return nil, plugins.ErrPluginNotRegistered.Errorf("%w", backendplugin.ErrPluginNotRegistered) @@ -65,6 +69,14 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) } func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return fmt.Errorf("req cannot be nil") + } + + if sender == nil { + return fmt.Errorf("sender cannot be nil") + } + p, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return backendplugin.ErrPluginNotRegistered @@ -84,6 +96,10 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq } func (s *Service) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + p, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return nil, backendplugin.ErrPluginNotRegistered @@ -102,6 +118,10 @@ func (s *Service) CollectMetrics(ctx context.Context, req *backend.CollectMetric } func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + p, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return nil, backendplugin.ErrPluginNotRegistered @@ -129,6 +149,10 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque } func (s *Service) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + plugin, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return nil, backendplugin.ErrPluginNotRegistered @@ -138,6 +162,10 @@ func (s *Service) SubscribeStream(ctx context.Context, req *backend.SubscribeStr } func (s *Service) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + plugin, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return nil, backendplugin.ErrPluginNotRegistered @@ -147,6 +175,14 @@ func (s *Service) PublishStream(ctx context.Context, req *backend.PublishStreamR } func (s *Service) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return fmt.Errorf("req cannot be nil") + } + + if sender == nil { + return fmt.Errorf("sender cannot be nil") + } + plugin, exists := s.plugin(ctx, req.PluginContext.PluginID) if !exists { return backendplugin.ErrPluginNotRegistered diff --git a/pkg/plugins/manager/client/clienttest/clienttest.go b/pkg/plugins/manager/client/clienttest/clienttest.go new file mode 100644 index 00000000000..f88d07d3376 --- /dev/null +++ b/pkg/plugins/manager/client/clienttest/clienttest.go @@ -0,0 +1,205 @@ +package clienttest + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/plugins/manager/client" + "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/web" + "github.com/stretchr/testify/require" +) + +type TestClient struct { + plugins.Client + QueryDataFunc backend.QueryDataHandlerFunc + CallResourceFunc backend.CallResourceHandlerFunc + CheckHealthFunc backend.CheckHealthHandlerFunc +} + +func (c *TestClient) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if c.QueryDataFunc != nil { + return c.QueryDataFunc(ctx, req) + } + + return nil, nil +} + +func (c *TestClient) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if c.CallResourceFunc != nil { + return c.CallResourceFunc(ctx, req, sender) + } + + return nil +} + +func (c *TestClient) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if c.CheckHealthFunc != nil { + return c.CheckHealthFunc(ctx, req) + } + + return nil, nil +} + +type MiddlewareScenarioContext struct { + QueryDataCallChain []string + CallResourceCallChain []string + CollectMetricsCallChain []string + CheckHealthCallChain []string + SubscribeStreamCallChain []string + PublishStreamCallChain []string + RunStreamCallChain []string +} + +func (ctx *MiddlewareScenarioContext) NewMiddleware(name string) plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &TestMiddleware{ + next: next, + Name: name, + sCtx: ctx, + } + }) +} + +type TestMiddleware struct { + next plugins.Client + sCtx *MiddlewareScenarioContext + Name string +} + +func (m *TestMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.QueryData(ctx, req) + m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("before %s", m.Name)) + err := m.next.CallResource(ctx, req, sender) + m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("after %s", m.Name)) + return err +} + +func (m *TestMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.CollectMetrics(ctx, req) + m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.CheckHealth(ctx, req) + m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.SubscribeStream(ctx, req) + m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.PublishStream(ctx, req) + m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("before %s", m.Name)) + err := m.next.RunStream(ctx, req, sender) + m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return err +} + +var _ plugins.Client = &TestClient{} + +type ClientDecoratorTest struct { + T *testing.T + Context context.Context + TestClient *TestClient + Middlewares []plugins.ClientMiddleware + Decorator *client.Decorator + ReqContext *models.ReqContext + QueryDataReq *backend.QueryDataRequest + QueryDataCtx context.Context + CallResourceReq *backend.CallResourceRequest + CallResourceCtx context.Context + CheckHealthReq *backend.CheckHealthRequest + CheckHealthCtx context.Context +} + +type ClientDecoratorTestOption func(*ClientDecoratorTest) + +func NewClientDecoratorTest(t *testing.T, opts ...ClientDecoratorTestOption) *ClientDecoratorTest { + cdt := &ClientDecoratorTest{ + T: t, + Context: context.Background(), + } + cdt.TestClient = &TestClient{ + QueryDataFunc: func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + cdt.QueryDataReq = req + cdt.QueryDataCtx = ctx + return nil, nil + }, + CallResourceFunc: func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + cdt.CallResourceReq = req + cdt.CallResourceCtx = ctx + return nil + }, + CheckHealthFunc: func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + cdt.CheckHealthReq = req + cdt.CheckHealthCtx = ctx + return nil, nil + }, + } + require.NotNil(t, cdt) + + for _, opt := range opts { + opt(cdt) + } + + d, err := client.NewDecorator(cdt.TestClient, cdt.Middlewares...) + require.NoError(t, err) + require.NotNil(t, d) + + cdt.Decorator = d + + return cdt +} + +func WithReqContext(req *http.Request, user *user.SignedInUser) ClientDecoratorTestOption { + return ClientDecoratorTestOption(func(cdt *ClientDecoratorTest) { + if cdt.ReqContext == nil { + cdt.ReqContext = &models.ReqContext{ + Context: &web.Context{}, + SignedInUser: user, + } + } + + cdt.Context = ctxkey.Set(cdt.Context, cdt.ReqContext) + + *req = *req.WithContext(cdt.Context) + cdt.ReqContext.Req = req + }) +} + +func WithMiddlewares(middlewares ...plugins.ClientMiddleware) ClientDecoratorTestOption { + return ClientDecoratorTestOption(func(cdt *ClientDecoratorTest) { + if cdt.Middlewares == nil { + cdt.Middlewares = []plugins.ClientMiddleware{} + } + + cdt.Middlewares = append(cdt.Middlewares, middlewares...) + }) +} diff --git a/pkg/plugins/manager/client/decorator.go b/pkg/plugins/manager/client/decorator.go new file mode 100644 index 00000000000..d38f87c03d0 --- /dev/null +++ b/pkg/plugins/manager/client/decorator.go @@ -0,0 +1,126 @@ +package client + +import ( + "context" + "fmt" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/plugins" +) + +// Decorator allows a plugins.Client to be decorated with middlewares. +type Decorator struct { + client plugins.Client + middlewares []plugins.ClientMiddleware +} + +// NewDecorator creates a new plugins.client decorator. +func NewDecorator(client plugins.Client, middlewares ...plugins.ClientMiddleware) (*Decorator, error) { + if client == nil { + return nil, fmt.Errorf("client cannot be nil") + } + + return &Decorator{ + client: client, + middlewares: middlewares, + }, nil +} + +func (d *Decorator) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.QueryData(ctx, req) +} + +func (d *Decorator) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return fmt.Errorf("req cannot be nil") + } + + if sender == nil { + return fmt.Errorf("sender cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.CallResource(ctx, req, sender) +} + +func (d *Decorator) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.CollectMetrics(ctx, req) +} + +func (d *Decorator) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.CheckHealth(ctx, req) +} + +func (d *Decorator) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.SubscribeStream(ctx, req) +} + +func (d *Decorator) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return nil, fmt.Errorf("req cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.PublishStream(ctx, req) +} + +func (d *Decorator) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return fmt.Errorf("req cannot be nil") + } + + if sender == nil { + return fmt.Errorf("sender cannot be nil") + } + + client := clientFromMiddlewares(d.middlewares, d.client) + return client.RunStream(ctx, req, sender) +} + +func clientFromMiddlewares(middlewares []plugins.ClientMiddleware, finalClient plugins.Client) plugins.Client { + if len(middlewares) == 0 { + return finalClient + } + + reversed := reverseMiddlewares(middlewares) + next := finalClient + + for _, m := range reversed { + next = m.CreateClientMiddleware(next) + } + + return next +} + +func reverseMiddlewares(middlewares []plugins.ClientMiddleware) []plugins.ClientMiddleware { + reversed := make([]plugins.ClientMiddleware, len(middlewares)) + copy(reversed, middlewares) + + for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 { + reversed[i], reversed[j] = reversed[j], reversed[i] + } + + return reversed +} + +var _ plugins.Client = &Decorator{} diff --git a/pkg/plugins/manager/client/decorator_test.go b/pkg/plugins/manager/client/decorator_test.go new file mode 100644 index 00000000000..abe533946cb --- /dev/null +++ b/pkg/plugins/manager/client/decorator_test.go @@ -0,0 +1,229 @@ +package client + +import ( + "context" + "fmt" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/plugins" + "github.com/stretchr/testify/require" +) + +func TestDecorator(t *testing.T) { + var queryDataCalled bool + var callResourceCalled bool + var checkHealthCalled bool + c := &TestClient{ + QueryDataFunc: func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + queryDataCalled = true + return nil, nil + }, + CallResourceFunc: func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + callResourceCalled = true + return nil + }, + CheckHealthFunc: func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + checkHealthCalled = true + return nil, nil + }, + } + require.NotNil(t, c) + + ctx := MiddlewareScenarioContext{} + + mwOne := ctx.NewMiddleware("mw1") + mwTwo := ctx.NewMiddleware("mw2") + + d, err := NewDecorator(c, mwOne, mwTwo) + require.NoError(t, err) + require.NotNil(t, d) + + _, _ = d.QueryData(context.Background(), &backend.QueryDataRequest{}) + require.True(t, queryDataCalled) + + sender := callResourceResponseSenderFunc(func(res *backend.CallResourceResponse) error { + return nil + }) + + _ = d.CallResource(context.Background(), &backend.CallResourceRequest{}, sender) + require.True(t, callResourceCalled) + + _, _ = d.CheckHealth(context.Background(), &backend.CheckHealthRequest{}) + require.True(t, checkHealthCalled) + + require.Len(t, ctx.QueryDataCallChain, 4) + require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.QueryDataCallChain) + require.Len(t, ctx.CallResourceCallChain, 4) + require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.CallResourceCallChain) + require.Len(t, ctx.CheckHealthCallChain, 4) + require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.CheckHealthCallChain) +} + +func TestReverseMiddlewares(t *testing.T) { + t.Run("Should reverse 1 middleware", func(t *testing.T) { + ctx := MiddlewareScenarioContext{} + middlewares := []plugins.ClientMiddleware{ + ctx.NewMiddleware("mw1"), + } + reversed := reverseMiddlewares(middlewares) + require.Len(t, reversed, 1) + require.Equal(t, "mw1", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name) + }) + + t.Run("Should reverse 2 middlewares", func(t *testing.T) { + ctx := MiddlewareScenarioContext{} + middlewares := []plugins.ClientMiddleware{ + ctx.NewMiddleware("mw1"), + ctx.NewMiddleware("mw2"), + } + reversed := reverseMiddlewares(middlewares) + require.Len(t, reversed, 2) + require.Equal(t, "mw2", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw1", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name) + }) + + t.Run("Should reverse 3 middlewares", func(t *testing.T) { + ctx := MiddlewareScenarioContext{} + middlewares := []plugins.ClientMiddleware{ + ctx.NewMiddleware("mw1"), + ctx.NewMiddleware("mw2"), + ctx.NewMiddleware("mw3"), + } + reversed := reverseMiddlewares(middlewares) + require.Len(t, reversed, 3) + require.Equal(t, "mw3", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw2", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw1", reversed[2].CreateClientMiddleware(nil).(*TestMiddleware).Name) + }) + + t.Run("Should reverse 4 middlewares", func(t *testing.T) { + ctx := MiddlewareScenarioContext{} + middlewares := []plugins.ClientMiddleware{ + ctx.NewMiddleware("mw1"), + ctx.NewMiddleware("mw2"), + ctx.NewMiddleware("mw3"), + ctx.NewMiddleware("mw4"), + } + reversed := reverseMiddlewares(middlewares) + require.Len(t, reversed, 4) + require.Equal(t, "mw4", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw3", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw2", reversed[2].CreateClientMiddleware(nil).(*TestMiddleware).Name) + require.Equal(t, "mw1", reversed[3].CreateClientMiddleware(nil).(*TestMiddleware).Name) + }) +} + +type TestClient struct { + plugins.Client + QueryDataFunc backend.QueryDataHandlerFunc + CallResourceFunc backend.CallResourceHandlerFunc + CheckHealthFunc backend.CheckHealthHandlerFunc +} + +func (c *TestClient) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if c.QueryDataFunc != nil { + return c.QueryDataFunc(ctx, req) + } + + return nil, nil +} + +func (c *TestClient) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if c.CallResourceFunc != nil { + return c.CallResourceFunc(ctx, req, sender) + } + + return nil +} + +func (c *TestClient) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if c.CheckHealthFunc != nil { + return c.CheckHealthFunc(ctx, req) + } + + return nil, nil +} + +type MiddlewareScenarioContext struct { + QueryDataCallChain []string + CallResourceCallChain []string + CollectMetricsCallChain []string + CheckHealthCallChain []string + SubscribeStreamCallChain []string + PublishStreamCallChain []string + RunStreamCallChain []string +} + +func (ctx *MiddlewareScenarioContext) NewMiddleware(name string) plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &TestMiddleware{ + next: next, + Name: name, + sCtx: ctx, + } + }) +} + +type TestMiddleware struct { + next plugins.Client + sCtx *MiddlewareScenarioContext + Name string +} + +func (m *TestMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.QueryData(ctx, req) + m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("before %s", m.Name)) + err := m.next.CallResource(ctx, req, sender) + m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("after %s", m.Name)) + return err +} + +func (m *TestMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.CollectMetrics(ctx, req) + m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.CheckHealth(ctx, req) + m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.SubscribeStream(ctx, req) + m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("before %s", m.Name)) + res, err := m.next.PublishStream(ctx, req) + m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return res, err +} + +func (m *TestMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("before %s", m.Name)) + err := m.next.RunStream(ctx, req, sender) + m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("after %s", m.Name)) + return err +} + +type callResourceResponseSenderFunc func(res *backend.CallResourceResponse) error + +func (fn callResourceResponseSenderFunc) Send(res *backend.CallResourceResponse) error { + return fn(res) +} + +var _ plugins.Client = &TestClient{} diff --git a/pkg/server/test_env.go b/pkg/server/test_env.go index 2bc16f2aab7..723c3d5659e 100644 --- a/pkg/server/test_env.go +++ b/pkg/server/test_env.go @@ -1,13 +1,32 @@ package server import ( + "github.com/grafana/grafana/pkg/infra/httpclient" + "github.com/grafana/grafana/pkg/plugins/manager/registry" "github.com/grafana/grafana/pkg/services/grpcserver" "github.com/grafana/grafana/pkg/services/notifications" + "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/sqlstore" ) -func ProvideTestEnv(server *Server, store *sqlstore.SQLStore, ns *notifications.NotificationServiceMock, grpcServer grpcserver.Provider) (*TestEnv, error) { - return &TestEnv{server, store, ns, grpcServer}, nil +func ProvideTestEnv( + server *Server, + store *sqlstore.SQLStore, + ns *notifications.NotificationServiceMock, + grpcServer grpcserver.Provider, + pluginRegistry registry.Service, + httpClientProvider httpclient.Provider, + oAuthTokenService *oauthtokentest.Service, +) (*TestEnv, error) { + return &TestEnv{ + server, + store, + ns, + grpcServer, + pluginRegistry, + httpClientProvider, + oAuthTokenService, + }, nil } type TestEnv struct { @@ -15,4 +34,7 @@ type TestEnv struct { SQLStore *sqlstore.SQLStore NotificationService *notifications.NotificationServiceMock GRPCServer grpcserver.Provider + PluginRegistry registry.Service + HTTPClientProvider httpclient.Provider + OAuthTokenService *oauthtokentest.Service } diff --git a/pkg/server/wire.go b/pkg/server/wire.go index bfac09f2aec..5a0bc6f3c72 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -28,18 +28,7 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/csrf" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/plugins" - "github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin" - pluginsCfg "github.com/grafana/grafana/pkg/plugins/config" - "github.com/grafana/grafana/pkg/plugins/manager" - "github.com/grafana/grafana/pkg/plugins/manager/client" pluginDashboards "github.com/grafana/grafana/pkg/plugins/manager/dashboards" - "github.com/grafana/grafana/pkg/plugins/manager/loader" - processManager "github.com/grafana/grafana/pkg/plugins/manager/process" - "github.com/grafana/grafana/pkg/plugins/manager/registry" - managerStore "github.com/grafana/grafana/pkg/plugins/manager/store" - "github.com/grafana/grafana/pkg/plugins/plugincontext" - "github.com/grafana/grafana/pkg/plugins/repo" "github.com/grafana/grafana/pkg/registry/corekind" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" @@ -93,12 +82,14 @@ import ( ngstore "github.com/grafana/grafana/pkg/services/ngalert/store" "github.com/grafana/grafana/pkg/services/notifications" "github.com/grafana/grafana/pkg/services/oauthtoken" + "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/org/orgimpl" "github.com/grafana/grafana/pkg/services/playlist/playlistimpl" "github.com/grafana/grafana/pkg/services/plugindashboards" plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service" "github.com/grafana/grafana/pkg/services/pluginsettings" pluginSettings "github.com/grafana/grafana/pkg/services/pluginsettings/service" + "github.com/grafana/grafana/pkg/services/pluginsintegration" "github.com/grafana/grafana/pkg/services/preference/prefimpl" "github.com/grafana/grafana/pkg/services/publicdashboards" publicdashboardsApi "github.com/grafana/grafana/pkg/services/publicdashboards/api" @@ -192,28 +183,9 @@ var wireBasicSet = wire.NewSet( updatechecker.ProvidePluginsService, uss.ProvideService, wire.Bind(new(usagestats.Service), new(*uss.UsageStats)), - registry.ProvideService, - wire.Bind(new(registry.Service), new(*registry.InMemory)), - pluginsCfg.ProvideConfig, - repo.ProvideService, - wire.Bind(new(repo.Service), new(*repo.Manager)), - manager.ProvideInstaller, - wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)), - client.ProvideService, - wire.Bind(new(plugins.Client), new(*client.Service)), - managerStore.ProvideService, - wire.Bind(new(plugins.Store), new(*managerStore.Service)), - wire.Bind(new(plugins.RendererManager), new(*managerStore.Service)), - wire.Bind(new(plugins.SecretsPluginManager), new(*managerStore.Service)), - wire.Bind(new(plugins.StaticRouteResolver), new(*managerStore.Service)), + pluginsintegration.WireSet, pluginDashboards.ProvideFileStoreManager, wire.Bind(new(pluginDashboards.FileStore), new(*pluginDashboards.FileStoreManager)), - processManager.ProvideService, - wire.Bind(new(processManager.Service), new(*processManager.Manager)), - coreplugin.ProvideCoreRegistry, - loader.ProvideService, - wire.Bind(new(loader.Service), new(*loader.Loader)), - wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)), cloudwatch.ProvideService, cloudmonitoring.ProvideService, azuremonitor.ProvideService, @@ -251,7 +223,6 @@ var wireBasicSet = wire.NewSet( export.ProvideService, live.ProvideService, pushhttp.ProvideService, - plugincontext.ProvideService, contexthandler.ProvideService, jwt.ProvideService, wire.Bind(new(models.JWTService), new(*jwt.AuthService)), @@ -271,8 +242,6 @@ var wireBasicSet = wire.NewSet( social.ProvideService, influxdb.ProvideService, wire.Bind(new(social.Service), new(*social.SocialService)), - oauthtoken.ProvideService, - wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), tempo.ProvideService, loki.ProvideService, graphite.ProvideService, @@ -391,6 +360,8 @@ var wireSet = wire.NewSet( wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), prefimpl.ProvideService, + oauthtoken.ProvideService, + wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), ) var wireTestSet = wire.NewSet( @@ -407,6 +378,9 @@ var wireTestSet = wire.NewSet( wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), prefimpl.ProvideService, + oauthtoken.ProvideService, + oauthtokentest.ProvideService, + wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtokentest.Service)), ) func Initialize(cla setting.CommandLineArgs, opts Options, apiOpts api.ServerOptions) (*Server, error) { diff --git a/pkg/server/wireexts_oss.go b/pkg/server/wireexts_oss.go index 6fae56eaaca..b69c63daa7e 100644 --- a/pkg/server/wireexts_oss.go +++ b/pkg/server/wireexts_oss.go @@ -8,8 +8,6 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" - "github.com/grafana/grafana/pkg/plugins/backendplugin/provider" - "github.com/grafana/grafana/pkg/plugins/manager/signature" "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/server/backgroundsvcs" "github.com/grafana/grafana/pkg/server/usagestatssvcs" @@ -29,6 +27,7 @@ import ( "github.com/grafana/grafana/pkg/services/licensing" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login/authinfoservice" + "github.com/grafana/grafana/pkg/services/pluginsintegration" "github.com/grafana/grafana/pkg/services/provisioning" "github.com/grafana/grafana/pkg/services/searchusers" "github.com/grafana/grafana/pkg/services/searchusers/filters" @@ -71,10 +70,6 @@ var wireExtsBasicSet = wire.NewSet( wire.Bind(new(user.SearchUserFilter), new(*filters.OSSSearchUserFilter)), searchusers.ProvideUsersService, wire.Bind(new(searchusers.Service), new(*searchusers.OSSService)), - signature.ProvideOSSAuthorizer, - wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)), - provider.ProvideService, - wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)), osskmsproviders.ProvideService, wire.Bind(new(kmsproviders.Service), new(osskmsproviders.Service)), ldap.ProvideGroupsService, @@ -85,6 +80,7 @@ var wireExtsBasicSet = wire.NewSet( wire.Bind(new(registry.UsageStatsProvidersRegistry), new(*usagestatssvcs.UsageStatsProvidersRegistry)), ossaccesscontrol.ProvideDatasourcePermissionsService, wire.Bind(new(accesscontrol.DatasourcePermissionsService), new(*ossaccesscontrol.DatasourcePermissionsService)), + pluginsintegration.WireExtensionSet, ) var wireExtsSet = wire.NewSet( diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index 250a8f87ead..4dffa911b48 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -79,7 +79,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs // IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source. func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { - return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool() + return IsOAuthPassThruEnabled(ds) } // HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User @@ -204,6 +204,11 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.Us return token, nil } +// IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source. +func IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { + return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool() +} + // tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in func tokensEq(t1, t2 *oauth2.Token) bool { return t1.AccessToken == t2.AccessToken && diff --git a/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go new file mode 100644 index 00000000000..230b2d7e290 --- /dev/null +++ b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go @@ -0,0 +1,41 @@ +package oauthtokentest + +import ( + "context" + + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/oauthtoken" + "github.com/grafana/grafana/pkg/services/user" + "golang.org/x/oauth2" +) + +// Service an OAuth token service suitable for tests. +type Service struct { + Token *oauth2.Token +} + +// ProvideService provides an OAuth token service suitable for tests. +func ProvideService() *Service { + return &Service{} +} + +func (s *Service) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { + return s.Token +} + +func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { + return oauthtoken.IsOAuthPassThruEnabled(ds) +} + +func (s *Service) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + return nil, false, nil +} + +func (s *Service) TryTokenRefresh(context.Context, *models.UserAuth) error { + return nil +} + +func (s *Service) InvalidateOAuthTokens(context.Context, *models.UserAuth) error { + return nil +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go new file mode 100644 index 00000000000..10a7c98be84 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go @@ -0,0 +1,87 @@ +package clientmiddleware + +import ( + "context" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/services/contexthandler" +) + +// NewClearAuthHeadersMiddleware creates a new plugins.ClientMiddleware +// that will clear any outgoing HTTP headers that was part of the incoming +// HTTP request and used when authenticating to Grafana. +func NewClearAuthHeadersMiddleware() plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &ClearAuthHeadersMiddleware{ + next: next, + } + }) +} + +type ClearAuthHeadersMiddleware struct { + next plugins.Client +} + +func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, pCtx backend.PluginContext, req interface{}) context.Context { + reqCtx := contexthandler.FromContext(ctx) + // if no HTTP request context skip middleware + if req == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil { + return ctx + } + + list := contexthandler.AuthHTTPHeaderListFromContext(ctx) + if list != nil { + ctx = sdkhttpclient.WithContextualMiddleware(ctx, httpclientprovider.DeleteHeadersMiddleware(list.Items...)) + } + + return ctx +} + +func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return m.next.QueryData(ctx, req) + } + + ctx = m.clearHeaders(ctx, req.PluginContext, req) + + return m.next.QueryData(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return m.next.CallResource(ctx, req, sender) + } + + ctx = m.clearHeaders(ctx, req.PluginContext, req) + + return m.next.CallResource(ctx, req, sender) +} + +func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return m.next.CheckHealth(ctx, req) + } + + ctx = m.clearHeaders(ctx, req.PluginContext, req) + + return m.next.CheckHealth(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + return m.next.CollectMetrics(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + return m.next.SubscribeStream(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + return m.next.PublishStream(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + return m.next.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go new file mode 100644 index 00000000000..195b3bd75dd --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go @@ -0,0 +1,310 @@ +package clientmiddleware + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/user" + "github.com/stretchr/testify/require" +) + +func TestClearAuthHeadersMiddleware(t *testing.T) { + const otherHeader = "test" + + t.Run("When no auth headers in reqContext", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + req.Header.Set(otherHeader, "test") + + t.Run("And requests are for a datasource", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), + ) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } + + t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 0) + }) + }) + + t.Run("And requests are for an app", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), + ) + + pluginCtx := backend.PluginContext{ + AppInstanceSettings: &backend.AppInstanceSettings{}, + } + + t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 0) + }) + }) + }) + + t.Run("When auth headers in reqContext", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + t.Run("And requests are for a datasource", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), + ) + + const customHeader = "X-Custom" + req.Header.Set(customHeader, "val") + ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) + req = req.WithContext(ctx) + + const otherHeader = "X-Other" + req.Header.Set(otherHeader, "test") + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } + + t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + }) + + t.Run("And requests are for an app", func(t *testing.T) { + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), + ) + + const customHeader = "X-Custom" + req.Header.Set(customHeader, "val") + ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) + req = req.WithContext(ctx) + + const otherHeader = "X-Other" + req.Header.Set(otherHeader, "test") + + pluginCtx := backend.PluginContext{ + AppInstanceSettings: &backend.AppInstanceSettings{}, + } + + t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Empty(t, reqClone.Header[customHeader]) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + }) + }) +} + +var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Request: req, + Body: io.NopCloser(bytes.NewBufferString("")), + }, nil +}) diff --git a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go new file mode 100644 index 00000000000..1c0824cc084 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware.go @@ -0,0 +1,126 @@ +package clientmiddleware + +import ( + "context" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/util/proxyutil" +) + +const cookieHeaderName = "Cookie" + +// NewCookiesMiddleware creates a new plugins.ClientMiddleware that will +// forward incoming HTTP request Cookies to outgoing plugins.Client and +// HTTP requests if the datasource has enabled forwarding of cookies (keepCookies). +func NewCookiesMiddleware(skipCookiesNames []string) plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &CookiesMiddleware{ + next: next, + skipCookiesNames: skipCookiesNames, + } + }) +} + +type CookiesMiddleware struct { + next plugins.Client + skipCookiesNames []string +} + +func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) { + reqCtx := contexthandler.FromContext(ctx) + // if request not for a datasource or no HTTP request context skip middleware + if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil { + return ctx, nil + } + + settings := pCtx.DataSourceInstanceSettings + jsonDataBytes, err := simplejson.NewJson(settings.JSONData) + if err != nil { + return ctx, err + } + + ds := &datasources.DataSource{ + Id: settings.ID, + OrgId: pCtx.OrgID, + JsonData: jsonDataBytes, + Updated: settings.Updated, + } + + proxyutil.ClearCookieHeader(reqCtx.Req, ds.AllowedCookies(), m.skipCookiesNames) + + if cookieStr := reqCtx.Req.Header.Get(cookieHeaderName); cookieStr != "" { + switch t := req.(type) { + case *backend.QueryDataRequest: + t.Headers[cookieHeaderName] = cookieStr + case *backend.CheckHealthRequest: + t.Headers[cookieHeaderName] = cookieStr + case *backend.CallResourceRequest: + t.Headers[cookieHeaderName] = []string{cookieStr} + } + } + + ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(reqCtx.Req.Cookies(), ds.AllowedCookies(), m.skipCookiesNames)) + + return ctx, nil +} + +func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return m.next.QueryData(ctx, req) + } + + newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.next.QueryData(newCtx, req) +} + +func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return m.next.CallResource(ctx, req, sender) + } + + newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + if err != nil { + return err + } + + return m.next.CallResource(newCtx, req, sender) +} + +func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return m.next.CheckHealth(ctx, req) + } + + newCtx, err := m.applyCookies(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.next.CheckHealth(newCtx, req) +} + +func (m *CookiesMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + return m.next.CollectMetrics(ctx, req) +} + +func (m *CookiesMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + return m.next.SubscribeStream(ctx, req) +} + +func (m *CookiesMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + return m.next.PublishStream(ctx, req) +} + +func (m *CookiesMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + return m.next.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go new file mode 100644 index 00000000000..a1493ea9742 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/cookies_middleware_test.go @@ -0,0 +1,223 @@ +package clientmiddleware + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" + "github.com/grafana/grafana/pkg/services/user" + "github.com/stretchr/testify/require" +) + +func TestCookiesMiddleware(t *testing.T) { + const otherHeader = "test" + + t.Run("When keepCookies not configured for a datasource", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "cookie1", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie2", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie3", + }) + req.Header.Set(otherHeader, "test") + + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewCookiesMiddleware([]string{"grafana_session"})), + ) + + jsonDataMap := map[string]interface{}{} + jsonDataBytes, err := json.Marshal(&jsonDataMap) + require.NoError(t, err) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{ + JSONData: jsonDataBytes, + }, + } + + t.Run("Should not forward cookies when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should not forward cookies when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + + t.Run("Should not forward cookies when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 1) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + }) + }) + + t.Run("When keepCookies configured for a datasource", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + req.AddCookie(&http.Cookie{ + Name: "cookie1", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie2", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie3", + }) + req.AddCookie(&http.Cookie{ + Name: "grafana_session", + }) + + req.Header.Set(otherHeader, "test") + + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewCookiesMiddleware([]string{"grafana_session"})), + ) + + jsonDataMap := map[string]interface{}{ + "keepCookies": []string{"cookie2", "grafana_session"}, + } + jsonDataBytes, err := json.Marshal(&jsonDataMap) + require.NoError(t, err) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{ + JSONData: jsonDataBytes, + }, + } + + t.Run("Should forward cookies when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 2) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + require.EqualValues(t, "cookie2=", cdt.QueryDataReq.Headers[cookieHeaderName]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 2) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) + }) + + t.Run("Should forward cookies when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 2) + require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) + require.Len(t, cdt.CallResourceReq.Headers[cookieHeaderName], 1) + require.EqualValues(t, "cookie2=", cdt.CallResourceReq.Headers[cookieHeaderName][0]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 2) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) + }) + + t.Run("Should forward cookies when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 2) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + require.EqualValues(t, "cookie2=", cdt.CheckHealthReq.Headers[cookieHeaderName]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 2) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName)) + }) + }) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go new file mode 100644 index 00000000000..bf8cbaac4f8 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware.go @@ -0,0 +1,155 @@ +package clientmiddleware + +import ( + "context" + "fmt" + "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/oauthtoken" +) + +// NewOAuthTokenMiddleware creates a new plugins.ClientMiddleware that will +// set OAuth token headers on outgoing plugins.Client and HTTP requests if +// the datasource has enabled Forward OAuth Identity (oauthPassThru). +func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plugins.ClientMiddleware { + return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client { + return &OAuthTokenMiddleware{ + next: next, + oAuthTokenService: oAuthTokenService, + } + }) +} + +const ( + tokenHeaderName = "Authorization" + idTokenHeaderName = "X-ID-Token" +) + +type OAuthTokenMiddleware struct { + oAuthTokenService oauthtoken.OAuthTokenService + next plugins.Client +} + +func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) { + reqCtx := contexthandler.FromContext(ctx) + // if request not for a datasource or no HTTP request context skip middleware + if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil { + return ctx, nil + } + + settings := pCtx.DataSourceInstanceSettings + jsonDataBytes, err := simplejson.NewJson(settings.JSONData) + if err != nil { + return ctx, err + } + + ds := &datasources.DataSource{ + Id: settings.ID, + OrgId: pCtx.OrgID, + JsonData: jsonDataBytes, + Updated: settings.Updated, + } + + if m.oAuthTokenService.IsOAuthPassThruEnabled(ds) { + if token := m.oAuthTokenService.GetCurrentOAuthToken(ctx, reqCtx.SignedInUser); token != nil { + authorizationHeader := fmt.Sprintf("%s %s", token.Type(), token.AccessToken) + idTokenHeader := "" + + idToken, ok := token.Extra("id_token").(string) + if ok && idToken != "" { + idTokenHeader = idToken + } + + switch t := req.(type) { + case *backend.QueryDataRequest: + t.Headers[tokenHeaderName] = authorizationHeader + if idTokenHeader != "" { + t.Headers[idTokenHeaderName] = idTokenHeader + } + case *backend.CheckHealthRequest: + t.Headers[tokenHeaderName] = authorizationHeader + if idTokenHeader != "" { + t.Headers[idTokenHeaderName] = idTokenHeader + } + case *backend.CallResourceRequest: + t.Headers[tokenHeaderName] = []string{authorizationHeader} + if idTokenHeader != "" { + t.Headers[idTokenHeaderName] = []string{idTokenHeader} + } + } + + httpHeaders := http.Header{} + httpHeaders.Set(tokenHeaderName, authorizationHeader) + + if idTokenHeader != "" { + httpHeaders.Set(idTokenHeaderName, idTokenHeader) + } + + ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.SetHeadersMiddleware(httpHeaders)) + } + } + + return ctx, nil +} + +func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if req == nil { + return m.next.QueryData(ctx, req) + } + + newCtx, err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.next.QueryData(newCtx, req) +} + +func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if req == nil { + return m.next.CallResource(ctx, req, sender) + } + + newCtx, err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return err + } + + return m.next.CallResource(newCtx, req, sender) +} + +func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if req == nil { + return m.next.CheckHealth(ctx, req) + } + + newCtx, err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.next.CheckHealth(newCtx, req) +} + +func (m *OAuthTokenMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + return m.next.CollectMetrics(ctx, req) +} + +func (m *OAuthTokenMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + return m.next.SubscribeStream(ctx, req) +} + +func (m *OAuthTokenMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + return m.next.PublishStream(ctx, req) +} + +func (m *OAuthTokenMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + return m.next.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go new file mode 100644 index 00000000000..6a22a8c8921 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/oauthtoken_middleware_test.go @@ -0,0 +1,197 @@ +package clientmiddleware + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" + "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" + "github.com/grafana/grafana/pkg/services/user" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestOAuthTokenMiddleware(t *testing.T) { + const otherHeader = "test" + + t.Run("When oauthPassThru not configured for a datasource", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + req.Header.Set(otherHeader, "test") + + oAuthTokenService := &oauthtokentest.Service{} + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewOAuthTokenMiddleware(oAuthTokenService)), + ) + + jsonDataMap := map[string]interface{}{} + jsonDataBytes, err := json.Marshal(&jsonDataMap) + require.NoError(t, err) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{ + JSONData: jsonDataBytes, + }, + } + + t.Run("Should not forward OAuth Identity when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 1) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not forward OAuth Identity when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 1) + require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 0) + }) + + t.Run("Should not forward OAuth Identity when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 1) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 0) + }) + }) + + t.Run("When oauthPassThru configured for a datasource", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) + require.NoError(t, err) + + req.Header.Set(otherHeader, "test") + + token := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + } + token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) + oAuthTokenService := &oauthtokentest.Service{ + Token: token, + } + cdt := clienttest.NewClientDecoratorTest(t, + clienttest.WithReqContext(req, &user.SignedInUser{}), + clienttest.WithMiddlewares(NewOAuthTokenMiddleware(oAuthTokenService)), + ) + + jsonDataMap := map[string]interface{}{ + "oauthPassThru": true, + } + jsonDataBytes, err := json.Marshal(&jsonDataMap) + require.NoError(t, err) + + pluginCtx := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{ + JSONData: jsonDataBytes, + }, + } + + t.Run("Should forward OAuth Identity when calling QueryData", func(t *testing.T) { + _, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.QueryDataReq) + require.Len(t, cdt.QueryDataReq.Headers, 3) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName]) + require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 3) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) + }) + + t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) { + err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{ + PluginContext: pluginCtx, + Headers: map[string][]string{otherHeader: {"test"}}, + }, nopCallResourceSender) + require.NoError(t, err) + require.NotNil(t, cdt.CallResourceReq) + require.Len(t, cdt.CallResourceReq.Headers, 3) + require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) + require.Len(t, cdt.CallResourceReq.Headers[tokenHeaderName], 1) + require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0]) + require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1) + require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 3) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) + }) + + t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) { + _, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.CheckHealthReq) + require.Len(t, cdt.CheckHealthReq.Headers, 3) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName]) + require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName]) + + middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx) + require.Len(t, middlewares, 1) + require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + + reqClone := req.Clone(req.Context()) + res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Len(t, reqClone.Header, 3) + require.Equal(t, "test", reqClone.Header.Get(otherHeader)) + require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName)) + require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName)) + }) + }) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/testing.go b/pkg/services/pluginsintegration/clientmiddleware/testing.go new file mode 100644 index 00000000000..d26c462f388 --- /dev/null +++ b/pkg/services/pluginsintegration/clientmiddleware/testing.go @@ -0,0 +1,13 @@ +package clientmiddleware + +import "github.com/grafana/grafana-plugin-sdk-go/backend" + +type callResourceResponseSenderFunc func(res *backend.CallResourceResponse) error + +func (fn callResourceResponseSenderFunc) Send(res *backend.CallResourceResponse) error { + return fn(res) +} + +var nopCallResourceSender = callResourceResponseSenderFunc(func(res *backend.CallResourceResponse) error { + return nil +}) diff --git a/pkg/services/pluginsintegration/doc.go b/pkg/services/pluginsintegration/doc.go new file mode 100644 index 00000000000..8c83c1a4ef5 --- /dev/null +++ b/pkg/services/pluginsintegration/doc.go @@ -0,0 +1,4 @@ +// package pluginsintegration instantiate the plugins +// package (pkg/plugins) and all of its services/dependencies that +// Grafana needs to provide plugin support. +package pluginsintegration diff --git a/pkg/services/pluginsintegration/pluginsintegration.go b/pkg/services/pluginsintegration/pluginsintegration.go new file mode 100644 index 00000000000..cf9c38e76e5 --- /dev/null +++ b/pkg/services/pluginsintegration/pluginsintegration.go @@ -0,0 +1,81 @@ +package pluginsintegration + +import ( + "github.com/google/wire" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin" + "github.com/grafana/grafana/pkg/plugins/backendplugin/provider" + "github.com/grafana/grafana/pkg/plugins/config" + "github.com/grafana/grafana/pkg/plugins/manager" + "github.com/grafana/grafana/pkg/plugins/manager/client" + "github.com/grafana/grafana/pkg/plugins/manager/loader" + "github.com/grafana/grafana/pkg/plugins/manager/process" + "github.com/grafana/grafana/pkg/plugins/manager/registry" + "github.com/grafana/grafana/pkg/plugins/manager/signature" + "github.com/grafana/grafana/pkg/plugins/manager/store" + "github.com/grafana/grafana/pkg/plugins/plugincontext" + "github.com/grafana/grafana/pkg/plugins/repo" + "github.com/grafana/grafana/pkg/services/oauthtoken" + "github.com/grafana/grafana/pkg/services/pluginsintegration/clientmiddleware" + "github.com/grafana/grafana/pkg/setting" +) + +// WireSet provides a wire.ProviderSet of plugin providers. +var WireSet = wire.NewSet( + config.ProvideConfig, + store.ProvideService, + wire.Bind(new(plugins.Store), new(*store.Service)), + wire.Bind(new(plugins.RendererManager), new(*store.Service)), + wire.Bind(new(plugins.SecretsPluginManager), new(*store.Service)), + wire.Bind(new(plugins.StaticRouteResolver), new(*store.Service)), + ProvideClientDecorator, + wire.Bind(new(plugins.Client), new(*client.Decorator)), + process.ProvideService, + wire.Bind(new(process.Service), new(*process.Manager)), + coreplugin.ProvideCoreRegistry, + loader.ProvideService, + wire.Bind(new(loader.Service), new(*loader.Loader)), + wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)), + manager.ProvideInstaller, + wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)), + registry.ProvideService, + wire.Bind(new(registry.Service), new(*registry.InMemory)), + repo.ProvideService, + wire.Bind(new(repo.Service), new(*repo.Manager)), + plugincontext.ProvideService, +) + +// WireExtensionSet provides a wire.ProviderSet of plugin providers that can be +// extended. +var WireExtensionSet = wire.NewSet( + provider.ProvideService, + wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)), + signature.ProvideOSSAuthorizer, + wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)), +) + +func ProvideClientDecorator(cfg *setting.Cfg, pCfg *config.Cfg, + pluginRegistry registry.Service, + oAuthTokenService oauthtoken.OAuthTokenService) (*client.Decorator, error) { + return NewClientDecorator(cfg, pCfg, pluginRegistry, oAuthTokenService) +} + +func NewClientDecorator(cfg *setting.Cfg, pCfg *config.Cfg, + pluginRegistry registry.Service, + oAuthTokenService oauthtoken.OAuthTokenService) (*client.Decorator, error) { + c := client.ProvideService(pluginRegistry, pCfg) + middlewares := CreateMiddlewares(cfg, oAuthTokenService) + + return client.NewDecorator(c, middlewares...) +} + +func CreateMiddlewares(cfg *setting.Cfg, oAuthTokenService oauthtoken.OAuthTokenService) []plugins.ClientMiddleware { + skipCookiesNames := []string{cfg.LoginCookieName} + middlewares := []plugins.ClientMiddleware{ + clientmiddleware.NewClearAuthHeadersMiddleware(), + clientmiddleware.NewOAuthTokenMiddleware(oAuthTokenService), + clientmiddleware.NewCookiesMiddleware(skipCookiesNames), + } + + return middlewares +} diff --git a/pkg/services/publicdashboards/api/common_test.go b/pkg/services/publicdashboards/api/common_test.go index 116c952db6b..a2cf1bb6238 100644 --- a/pkg/services/publicdashboards/api/common_test.go +++ b/pkg/services/publicdashboards/api/common_test.go @@ -10,7 +10,6 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/infra/db" @@ -137,7 +136,6 @@ func buildQueryDataService(t *testing.T, cs datasources.CacheService, fpc *fakeP &fakePluginRequestValidator{}, &fakeDatasources.FakeDataSourceService{}, fpc, - &fakeOAuthTokenService{}, ) } @@ -150,31 +148,6 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request) return rv.err } -type fakeOAuthTokenService struct { - passThruEnabled bool - token *oauth2.Token -} - -func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { - return ts.token -} - -func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool { - return ts.passThruEnabled -} - -func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { - return nil, false, nil -} - -func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { - return nil -} - -func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { - return nil -} - // copied from pkg/api/plugins_test.go type fakePluginClient struct { plugins.Client diff --git a/pkg/services/query/models.go b/pkg/services/query/models.go new file mode 100644 index 00000000000..7e209923575 --- /dev/null +++ b/pkg/services/query/models.go @@ -0,0 +1,90 @@ +package query + +import ( + "context" + "strings" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/datasources" +) + +type parsedQuery struct { + datasource *datasources.DataSource + query backend.DataQuery + rawQuery *simplejson.Json +} + +type parsedRequest struct { + hasExpression bool + parsedQueries map[string][]parsedQuery + dsTypes map[string]bool +} + +func (pr parsedRequest) getFlattenedQueries() []parsedQuery { + queries := make([]parsedQuery, 0) + for _, pq := range pr.parsedQueries { + queries = append(queries, pq...) + } + return queries +} + +func (pr parsedRequest) validateRequest(ctx context.Context) error { + reqCtx := contexthandler.FromContext(ctx) + + if reqCtx == nil || reqCtx.Req == nil { + return nil + } + + httpReq := reqCtx.Req + + if pr.hasExpression { + hasExpr := httpReq.URL.Query().Get("expression") + if hasExpr == "" || hasExpr == "true" { + return nil + } + return ErrQueryParamMismatch + } + + vals := splitHeaders(httpReq.Header.Values(HeaderDatasourceUID)) + count := len(vals) + if count > 0 { // header exists + if count != len(pr.parsedQueries) { + return ErrQueryParamMismatch + } + for _, t := range vals { + if pr.parsedQueries[t] == nil { + return ErrQueryParamMismatch + } + } + } + + vals = splitHeaders(httpReq.Header.Values(HeaderPluginID)) + count = len(vals) + if count > 0 { // header exists + if count != len(pr.dsTypes) { + return ErrQueryParamMismatch + } + for _, t := range vals { + if !pr.dsTypes[t] { + return ErrQueryParamMismatch + } + } + } + return nil +} + +func splitHeaders(headers []string) []string { + out := []string{} + for _, v := range headers { + if strings.Contains(v, ",") { + for _, sub := range strings.Split(v, ",") { + out = append(out, strings.TrimSpace(sub)) + } + } else { + out = append(out, v) + } + } + return out +} diff --git a/pkg/services/query/query.go b/pkg/services/query/query.go index 5893f4573f5..2c0f730c0a9 100644 --- a/pkg/services/query/query.go +++ b/pkg/services/query/query.go @@ -5,26 +5,21 @@ import ( "fmt" "time" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/expr" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/adapters" "github.com/grafana/grafana/pkg/services/datasources" - "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/grafanads" "github.com/grafana/grafana/pkg/tsdb/legacydata" "github.com/grafana/grafana/pkg/util/errutil" - "github.com/grafana/grafana/pkg/util/proxyutil" "golang.org/x/sync/errgroup" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" ) const ( @@ -41,7 +36,6 @@ func ProvideService( pluginRequestValidator models.PluginRequestValidator, dataSourceService datasources.DataSourceService, pluginClient plugins.Client, - oAuthTokenService oauthtoken.OAuthTokenService, ) *Service { g := &Service{ cfg: cfg, @@ -50,7 +44,6 @@ func ProvideService( pluginRequestValidator: pluginRequestValidator, dataSourceService: dataSourceService, pluginClient: pluginClient, - oAuthTokenService: oAuthTokenService, log: log.New("query_data"), } g.log.Info("Query Service initialization") @@ -64,7 +57,6 @@ type Service struct { pluginRequestValidator models.PluginRequestValidator dataSourceService datasources.DataSourceService pluginClient plugins.Client - oAuthTokenService oauthtoken.OAuthTokenService log log.Logger } @@ -175,9 +167,6 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser exprReq.OrgId = user.OrgID } - disallowedCookies := []string{s.cfg.LoginCookieName} - queryEnrichers := parsedReq.createDataSourceQueryEnrichers(ctx, user, s.oAuthTokenService, disallowedCookies) - for _, pq := range parsedReq.getFlattenedQueries() { if pq.datasource == nil { return nil, ErrMissingDataSourceInfo.Build(errutil.TemplateData{ @@ -198,7 +187,6 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser From: pq.query.TimeRange.From, To: pq.query.TimeRange.To, }, - QueryEnricher: queryEnrichers[pq.datasource.Uid], }) } @@ -240,39 +228,10 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si Queries: []backend.DataQuery{}, } - disallowedCookies := []string{s.cfg.LoginCookieName} - middlewares := []httpclient.Middleware{} - if parsedReq.httpRequest != nil { - middlewares = append(middlewares, - httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), disallowedCookies), - ) - } - - if s.oAuthTokenService.IsOAuthPassThruEnabled(ds) { - if token := s.oAuthTokenService.GetCurrentOAuthToken(ctx, user); token != nil { - req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) - - idToken, ok := token.Extra("id_token").(string) - if ok && idToken != "" { - req.Headers["X-ID-Token"] = idToken - } - middlewares = append(middlewares, httpclientprovider.ForwardedOAuthIdentityMiddleware(token)) - } - } - - if parsedReq.httpRequest != nil { - proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), disallowedCookies) - if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" { - req.Headers["Cookie"] = cookieStr - } - } - for _, q := range queries { req.Queries = append(req.Queries, q.query) } - ctx = httpclient.WithContextualMiddleware(ctx, middlewares...) - return s.pluginClient.QueryData(ctx, req) } @@ -335,11 +294,7 @@ func (s *Service) parseMetricRequest(ctx context.Context, user *user.SignedInUse }) } - if reqDTO.HTTPRequest != nil { - req.httpRequest = reqDTO.HTTPRequest - } - - _ = req.validateRequest() + _ = req.validateRequest(ctx) return req, nil // TODO req.validateRequest() } diff --git a/pkg/services/query/query_parsing.go b/pkg/services/query/query_parsing.go deleted file mode 100644 index 9021d6a658c..00000000000 --- a/pkg/services/query/query_parsing.go +++ /dev/null @@ -1,157 +0,0 @@ -package query - -import ( - "context" - "fmt" - "net/http" - "strings" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/components/simplejson" - "github.com/grafana/grafana/pkg/expr" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" - "github.com/grafana/grafana/pkg/services/datasources" - "github.com/grafana/grafana/pkg/services/oauthtoken" - "github.com/grafana/grafana/pkg/services/user" - "github.com/grafana/grafana/pkg/util/proxyutil" - "golang.org/x/oauth2" -) - -type parsedQuery struct { - datasource *datasources.DataSource - query backend.DataQuery - rawQuery *simplejson.Json -} - -type parsedRequest struct { - hasExpression bool - parsedQueries map[string][]parsedQuery - dsTypes map[string]bool - httpRequest *http.Request -} - -func (pr parsedRequest) getFlattenedQueries() []parsedQuery { - queries := make([]parsedQuery, 0) - for _, pq := range pr.parsedQueries { - queries = append(queries, pq...) - } - return queries -} - -func (pr parsedRequest) validateRequest() error { - if pr.httpRequest == nil { - return nil - } - - if pr.hasExpression { - hasExpr := pr.httpRequest.URL.Query().Get("expression") - if hasExpr == "" || hasExpr == "true" { - return nil - } - return ErrQueryParamMismatch - } - - vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID)) - count := len(vals) - if count > 0 { // header exists - if count != len(pr.parsedQueries) { - return ErrQueryParamMismatch - } - for _, t := range vals { - if pr.parsedQueries[t] == nil { - return ErrQueryParamMismatch - } - } - } - - vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID)) - count = len(vals) - if count > 0 { // header exists - if count != len(pr.dsTypes) { - return ErrQueryParamMismatch - } - for _, t := range vals { - if !pr.dsTypes[t] { - return ErrQueryParamMismatch - } - } - } - return nil -} - -func (pr parsedRequest) createDataSourceQueryEnrichers(ctx context.Context, signedInUser *user.SignedInUser, oAuthTokenService oauthtoken.OAuthTokenService, disallowedCookies []string) map[string]expr.QueryDataRequestEnricher { - datasourcesHeaderProvider := map[string]expr.QueryDataRequestEnricher{} - - if pr.httpRequest == nil { - return datasourcesHeaderProvider - } - - for uid, queries := range pr.parsedQueries { - if expr.IsDataSource(uid) { - continue - } - - if len(queries) == 0 || queries[0].datasource == nil { - continue - } - - if _, exists := datasourcesHeaderProvider[uid]; exists { - continue - } - - ds := queries[0].datasource - allowedCookies := ds.AllowedCookies() - clonedReq := pr.httpRequest.Clone(pr.httpRequest.Context()) - - var token *oauth2.Token - - if oAuthTokenService.IsOAuthPassThruEnabled(ds) { - token = oAuthTokenService.GetCurrentOAuthToken(ctx, signedInUser) - } - - datasourcesHeaderProvider[uid] = func(ctx context.Context, req *backend.QueryDataRequest) context.Context { - if len(req.Headers) == 0 { - req.Headers = map[string]string{} - } - - if len(allowedCookies) > 0 { - proxyutil.ClearCookieHeader(clonedReq, allowedCookies, disallowedCookies) - if cookieStr := clonedReq.Header.Get("Cookie"); cookieStr != "" { - req.Headers["Cookie"] = cookieStr - } - - ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(clonedReq.Cookies(), allowedCookies, disallowedCookies)) - } - - if token != nil { - req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) - - idToken, ok := token.Extra("id_token").(string) - if ok && idToken != "" { - req.Headers["X-ID-Token"] = idToken - } - - ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedOAuthIdentityMiddleware(token)) - } - - return ctx - } - } - - return datasourcesHeaderProvider -} - -func splitHeaders(headers []string) []string { - out := []string{} - for _, v := range headers { - if strings.Contains(v, ",") { - for _, sub := range strings.Split(v, ",") { - out = append(out, strings.TrimSpace(sub)) - } - } else { - out = append(out, v) - } - } - return out -} diff --git a/pkg/services/query/query_test.go b/pkg/services/query/query_test.go index 2dd56fde9cd..e9c9c4a5c57 100644 --- a/pkg/services/query/query_test.go +++ b/pkg/services/query/query_test.go @@ -5,24 +5,22 @@ import ( "context" "errors" "net/http" - "net/http/httptest" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/expr" "github.com/grafana/grafana/pkg/infra/db" - "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/plugins" acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" + "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" "github.com/grafana/grafana/pkg/services/datasources" fakeDatasources "github.com/grafana/grafana/pkg/services/datasources/fakes" dsSvc "github.com/grafana/grafana/pkg/services/datasources/service" @@ -33,35 +31,12 @@ import ( secretsmng "github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/web" ) func TestParseMetricRequest(t *testing.T) { t.Run("Test a simple single datasource query", func(t *testing.T) { tc := setup(t) - json, err := simplejson.NewJson([]byte(`{ - "keepCookies": [ "cookie1", "cookie3", "login" ] - }`)) - require.NoError(t, err) - tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { - if datasourceUID == "gIEkMvIVz" { - return &datasources.DataSource{ - Uid: "gIEkMvIVz", - JsonData: json, - }, nil - } - - return nil, nil - } - - token := &oauth2.Token{ - TokenType: "bearer", - AccessToken: "access-token", - } - token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) - - tc.oauthTokenService.passThruEnabled = true - tc.oauthTokenService.token = token - mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -82,61 +57,10 @@ func TestParseMetricRequest(t *testing.T) { assert.Len(t, parsedReq.parsedQueries, 1) assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") assert.Len(t, parsedReq.getFlattenedQueries(), 2) - - t.Run("createDataSourceQueryEnrichers should return 0 enrichers when no HTTP request", func(t *testing.T) { - enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{}) - require.Empty(t, enrichers) - }) - - t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) { - parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) - - enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) - require.Len(t, enrichers, 1) - require.NotNil(t, enrichers["gIEkMvIVz"]) - req := &backend.QueryDataRequest{} - ctx := enrichers["gIEkMvIVz"](context.Background(), req) - require.Len(t, req.Headers, 3) - require.Equal(t, "Bearer access-token", req.Headers["Authorization"]) - require.Equal(t, "id-token", req.Headers["X-ID-Token"]) - require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"]) - middlewares := httpclient.ContextualMiddlewareFromContext(ctx) - require.Len(t, middlewares, 2) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName()) - }) }) t.Run("Test a single datasource query with expressions", func(t *testing.T) { tc := setup(t) - json, err := simplejson.NewJson([]byte(`{ - "keepCookies": [ "cookie1", "cookie3", "login" ] - }`)) - require.NoError(t, err) - tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { - if datasourceUID == "gIEkMvIVz" { - return &datasources.DataSource{ - Uid: "gIEkMvIVz", - JsonData: json, - }, nil - } - - return nil, nil - } - - token := &oauth2.Token{ - TokenType: "bearer", - AccessToken: "access-token", - } - token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) - - tc.oauthTokenService.passThruEnabled = true - tc.oauthTokenService.token = token - mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -156,75 +80,27 @@ func TestParseMetricRequest(t *testing.T) { parsedReq, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) require.NoError(t, err) require.NotNil(t, parsedReq) - assert.True(t, parsedReq.hasExpression) - assert.Len(t, parsedReq.parsedQueries, 2) - assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") - assert.Len(t, parsedReq.getFlattenedQueries(), 2) + require.True(t, parsedReq.hasExpression) + require.Len(t, parsedReq.parsedQueries, 2) + require.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") + require.Len(t, parsedReq.getFlattenedQueries(), 2) // Make sure we end up with something valid _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) - assert.NoError(t, err) + require.NoError(t, err) - t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) { - parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) - - enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) - require.Len(t, enrichers, 1) - require.NotNil(t, enrichers["gIEkMvIVz"]) - - req := &backend.QueryDataRequest{} - ctx := enrichers["gIEkMvIVz"](context.Background(), req) - require.Len(t, req.Headers, 3) - require.Equal(t, "Bearer access-token", req.Headers["Authorization"]) - require.Equal(t, "id-token", req.Headers["X-ID-Token"]) - require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"]) - middlewares := httpclient.ContextualMiddlewareFromContext(ctx) - require.Len(t, middlewares, 2) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName()) + t.Run("Should forward user and org ID to QueryData from expression request", func(t *testing.T) { + require.NotNil(t, tc.pluginContext.req) + require.NotNil(t, tc.pluginContext.req.PluginContext.User) + require.Equal(t, tc.signedInUser.Login, tc.pluginContext.req.PluginContext.User.Login) + require.Equal(t, tc.signedInUser.Name, tc.pluginContext.req.PluginContext.User.Name) + require.Equal(t, tc.signedInUser.Email, tc.pluginContext.req.PluginContext.User.Email) + require.Equal(t, string(tc.signedInUser.OrgRole), tc.pluginContext.req.PluginContext.User.Role) + require.Equal(t, tc.signedInUser.OrgID, tc.pluginContext.req.PluginContext.OrgID) }) }) t.Run("Test a simple mixed datasource query", func(t *testing.T) { tc := setup(t) - json, err := simplejson.NewJson([]byte(`{ - "keepCookies": [ "cookie1", "cookie3", "login" ] - }`)) - require.NoError(t, err) - json2, err := simplejson.NewJson([]byte(`{ - "keepCookies": [ "cookie2" ] - }`)) - require.NoError(t, err) - tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { - if datasourceUID == "gIEkMvIVz" { - return &datasources.DataSource{ - Uid: "gIEkMvIVz", - JsonData: json, - }, nil - } - - if datasourceUID == "sEx6ZvSVk" { - return &datasources.DataSource{ - Uid: "sEx6ZvSVk", - JsonData: json2, - }, nil - } - - return nil, nil - } - - token := &oauth2.Token{ - TokenType: "bearer", - AccessToken: "access-token", - } - token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) - - tc.oauthTokenService.passThruEnabled = true - tc.oauthTokenService.token = token - mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -254,43 +130,6 @@ func TestParseMetricRequest(t *testing.T) { assert.Contains(t, parsedReq.parsedQueries, "sEx6ZvSVk") assert.Len(t, parsedReq.parsedQueries["sEx6ZvSVk"], 2) assert.Len(t, parsedReq.getFlattenedQueries(), 3) - - t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) { - parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) - parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) - - enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) - require.Len(t, enrichers, 2) - - enricherOne := enrichers["gIEkMvIVz"] - require.NotNil(t, enricherOne) - reqOne := &backend.QueryDataRequest{} - ctx := enricherOne(context.Background(), reqOne) - require.Len(t, reqOne.Headers, 3) - require.Equal(t, "Bearer access-token", reqOne.Headers["Authorization"]) - require.Equal(t, "id-token", reqOne.Headers["X-ID-Token"]) - require.Equal(t, "cookie1=; cookie3=", reqOne.Headers["Cookie"]) - middlewaresOne := httpclient.ContextualMiddlewareFromContext(ctx) - require.Len(t, middlewaresOne, 2) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresOne[0].(httpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresOne[1].(httpclient.MiddlewareName).MiddlewareName()) - - enricherTwo := enrichers["sEx6ZvSVk"] - require.NotNil(t, enricherTwo) - reqTwo := &backend.QueryDataRequest{} - ctx = enricherTwo(context.Background(), reqTwo) - require.Len(t, reqTwo.Headers, 3) - require.Equal(t, "Bearer access-token", reqTwo.Headers["Authorization"]) - require.Equal(t, "id-token", reqTwo.Headers["X-ID-Token"]) - require.Equal(t, "cookie2=", reqTwo.Headers["Cookie"]) - middlewaresTwo := httpclient.ContextualMiddlewareFromContext(ctx) - require.Len(t, middlewaresTwo, 2) - require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresTwo[0].(httpclient.MiddlewareName).MiddlewareName()) - require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresTwo[1].(httpclient.MiddlewareName).MiddlewareName()) - }) }) t.Run("Test a mixed datasource query with expressions", func(t *testing.T) { @@ -352,14 +191,6 @@ func TestParseMetricRequest(t *testing.T) { // Make sure we end up with something valid _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) assert.NoError(t, err) - - t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) { - parsedReq.httpRequest = &http.Request{} - enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{}) - require.Len(t, enrichers, 2) - require.NotNil(t, enrichers["gIEkMvIVz"]) - require.NotNil(t, enrichers["sEx6ZvSVk"]) - }) }) t.Run("Header validation", func(t *testing.T) { @@ -377,23 +208,30 @@ func TestParseMetricRequest(t *testing.T) { "type": "testdata" } }`) - httpreq, _ := http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{})) + httpreq, err := http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{})) + require.NoError(t, err) + + reqCtx := &models.ReqContext{ + Context: &web.Context{}, + } + ctx := ctxkey.Set(context.Background(), reqCtx) + + *httpreq = *httpreq.WithContext(ctx) + reqCtx.Req = httpreq + httpreq.Header.Add("X-Datasource-Uid", "gIEkMvIVz") - mr.HTTPRequest = httpreq - _, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) + _, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr) require.NoError(t, err) // With the second value it is OK httpreq.Header.Add("X-Datasource-Uid", "sEx6ZvSVk") - mr.HTTPRequest = httpreq - _, err = tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) + _, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr) require.NoError(t, err) // Single header with comma syntax httpreq, _ = http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{})) httpreq.Header.Set("X-Datasource-Uid", "gIEkMvIVz, sEx6ZvSVk") - mr.HTTPRequest = httpreq - _, err = tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) + _, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr) require.NoError(t, err) }) } @@ -428,7 +266,6 @@ func TestQueryDataMultipleSources(t *testing.T) { Queries: queries, Debug: false, PublicDashboardAccessToken: "abc123", - HTTPRequest: nil, } _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO) @@ -479,19 +316,27 @@ func TestQueryDataMultipleSources(t *testing.T) { Queries: queries, Debug: false, PublicDashboardAccessToken: "abc123", - HTTPRequest: nil, } // without query parameter _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO) require.NoError(t, err) - httpreq, _ := http.NewRequest(http.MethodPost, "http://localhost/ds/query?expression=true", bytes.NewReader([]byte{})) + httpreq, err := http.NewRequest(http.MethodPost, "http://localhost/ds/query?expression=true", bytes.NewReader([]byte{})) + require.NoError(t, err) + + reqCtx := &models.ReqContext{ + Context: &web.Context{}, + } + ctx := ctxkey.Set(context.Background(), reqCtx) + + *httpreq = *httpreq.WithContext(ctx) + reqCtx.Req = httpreq + httpreq.Header.Add("X-Datasource-Uid", "gIEkMvIVz") - reqDTO.HTTPRequest = httpreq // with query parameter - _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO) + _, err = tc.queryService.QueryData(httpreq.Context(), tc.signedInUser, true, reqDTO) require.NoError(t, err) }) @@ -526,7 +371,6 @@ func TestQueryDataMultipleSources(t *testing.T) { Queries: queries, Debug: false, PublicDashboardAccessToken: "abc123", - HTTPRequest: nil, } res, err := tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO) @@ -538,99 +382,10 @@ func TestQueryDataMultipleSources(t *testing.T) { }) } -func TestQueryData(t *testing.T) { - t.Run("it auth custom headers to the request", func(t *testing.T) { - token := &oauth2.Token{ - TokenType: "bearer", - AccessToken: "access-token", - } - token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) - - tc := setup(t) - tc.oauthTokenService.passThruEnabled = true - tc.oauthTokenService.token = token - - metricReq := metricRequest() - httpReq, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - metricReq.HTTPRequest = httpReq - - _, err = tc.queryService.QueryData(context.Background(), nil, true, metricReq) - require.Nil(t, err) - - expected := map[string]string{ - "Authorization": "Bearer access-token", - "X-ID-Token": "id-token", - } - require.Equal(t, expected, tc.pluginContext.req.Headers) - }) - - t.Run("it doesn't add cookie header to the request when keepCookies configured and no cookies provided", func(t *testing.T) { - tc := setup(t) - json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "foo", "bar" ]}`)) - require.NoError(t, err) - tc.dataSourceCache.ds.JsonData = json - - metricReq := metricRequest() - httpReq, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - metricReq.HTTPRequest = httpReq - _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq) - require.NoError(t, err) - - require.Empty(t, tc.pluginContext.req.Headers) - }) - - t.Run("it adds cookie header to the request when keepCookies configured and cookie provided", func(t *testing.T) { - tc := setup(t) - json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "foo", "bar" ]}`)) - require.NoError(t, err) - tc.dataSourceCache.ds.JsonData = json - - metricReq := metricRequest() - httpReq, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - httpReq.AddCookie(&http.Cookie{Name: "a"}) - httpReq.AddCookie(&http.Cookie{Name: "bar", Value: "rab"}) - httpReq.AddCookie(&http.Cookie{Name: "b"}) - httpReq.AddCookie(&http.Cookie{Name: "foo", Value: "oof"}) - httpReq.AddCookie(&http.Cookie{Name: "c"}) - metricReq.HTTPRequest = httpReq - _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq) - require.NoError(t, err) - - require.Equal(t, map[string]string{"Cookie": "bar=rab; foo=oof"}, tc.pluginContext.req.Headers) - }) - - t.Run("it doesn't adds cookie header to the request when keepCookies configured with login cookie name", func(t *testing.T) { - tc := setup(t) - tc.queryService.cfg.LoginCookieName = "grafana_session" - json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "grafana_session", "bar" ]}`)) - require.NoError(t, err) - tc.dataSourceCache.ds.JsonData = json - - metricReq := metricRequest() - httpReq, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - httpReq.AddCookie(&http.Cookie{Name: "a"}) - httpReq.AddCookie(&http.Cookie{Name: "bar", Value: "rab"}) - httpReq.AddCookie(&http.Cookie{Name: "b"}) - httpReq.AddCookie(&http.Cookie{Name: "foo", Value: "oof"}) - httpReq.AddCookie(&http.Cookie{Name: "c"}) - httpReq.AddCookie(&http.Cookie{Name: tc.queryService.cfg.LoginCookieName, Value: "val"}) - metricReq.HTTPRequest = httpReq - _, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq) - require.NoError(t, err) - - require.Equal(t, map[string]string{"Cookie": "bar=rab"}, tc.pluginContext.req.Headers) - }) -} - func setup(t *testing.T) *testContext { t.Helper() pc := &fakePluginClient{} dc := &fakeDataSourceCache{ds: &datasources.DataSource{}} - tc := &fakeOAuthTokenService{} rv := &fakePluginRequestValidator{} sqlStore := db.InitTestDB(t) @@ -645,15 +400,14 @@ func setup(t *testing.T) *testContext { SimulatePluginFailure: false, } exprService := expr.ProvideService(&setting.Cfg{ExpressionsEnabled: true}, pc, fakeDatasourceService) - queryService := ProvideService(setting.NewCfg(), dc, exprService, rv, ds, pc, tc) // provider belonging to this package + queryService := ProvideService(setting.NewCfg(), dc, exprService, rv, ds, pc) // provider belonging to this package return &testContext{ pluginContext: pc, secretStore: ss, dataSourceCache: dc, - oauthTokenService: tc, pluginRequestValidator: rv, queryService: queryService, - signedInUser: &user.SignedInUser{OrgID: 1}, + signedInUser: &user.SignedInUser{OrgID: 1, Login: "login", Name: "name", Email: "email", OrgRole: roletype.RoleAdmin}, } } @@ -661,22 +415,11 @@ type testContext struct { pluginContext *fakePluginClient secretStore secretskvs.SecretsKVStore dataSourceCache *fakeDataSourceCache - oauthTokenService *fakeOAuthTokenService pluginRequestValidator *fakePluginRequestValidator queryService *Service // implementation belonging to this package signedInUser *user.SignedInUser } -func metricRequest() dtos.MetricRequest { - q, _ := simplejson.NewJson([]byte(`{"datasourceId":1}`)) - return dtos.MetricRequest{ - From: "", - To: "", - Queries: []*simplejson.Json{q}, - Debug: false, - } -} - func metricRequestWithQueries(t *testing.T, rawQueries ...string) dtos.MetricRequest { t.Helper() queries := make([]*simplejson.Json, 0) @@ -701,34 +444,8 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request) return rv.err } -type fakeOAuthTokenService struct { - passThruEnabled bool - token *oauth2.Token -} - -func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { - return ts.token -} - -func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool { - return ts.passThruEnabled -} - -func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { - return nil, false, nil -} - -func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error { - return nil -} - -func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error { - return nil -} - type fakeDataSourceCache struct { - ds *datasources.DataSource - dsByUid func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) + ds *datasources.DataSource } func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { @@ -736,10 +453,6 @@ func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID in } func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { - if c.dsByUid != nil { - return c.dsByUid(ctx, datasourceUID, user, skipCache) - } - return &datasources.DataSource{ Uid: datasourceUID, }, nil diff --git a/pkg/tests/api/elasticsearch/elasticsearch_test.go b/pkg/tests/api/elasticsearch/elasticsearch_test.go index e0d45eb1b0e..7d46ca2c8a3 100644 --- a/pkg/tests/api/elasticsearch/elasticsearch_test.go +++ b/pkg/tests/api/elasticsearch/elasticsearch_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationElasticsearch(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -107,14 +106,3 @@ func TestIntegrationElasticsearch(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/api/graphite/graphite_test.go b/pkg/tests/api/graphite/graphite_test.go index 7e84eef8de7..0b53ba56305 100644 --- a/pkg/tests/api/graphite/graphite_test.go +++ b/pkg/tests/api/graphite/graphite_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationGraphite(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -104,14 +103,3 @@ func TestIntegrationGraphite(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/api/influxdb/influxdb_test.go b/pkg/tests/api/influxdb/influxdb_test.go index de1cd72aef6..a76ba3171e7 100644 --- a/pkg/tests/api/influxdb/influxdb_test.go +++ b/pkg/tests/api/influxdb/influxdb_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationInflux(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -103,14 +102,3 @@ func TestIntegrationInflux(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/api/loki/loki_test.go b/pkg/tests/api/loki/loki_test.go index 670cec5730b..9cd826b4def 100644 --- a/pkg/tests/api/loki/loki_test.go +++ b/pkg/tests/api/loki/loki_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationLoki(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -102,14 +101,3 @@ func TestIntegrationLoki(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/api/opentdsb/opentdsb_test.go b/pkg/tests/api/opentdsb/opentdsb_test.go index 87dc100c524..12598b02f9d 100644 --- a/pkg/tests/api/opentdsb/opentdsb_test.go +++ b/pkg/tests/api/opentdsb/opentdsb_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationOpenTSDB(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -103,14 +102,3 @@ func TestIntegrationOpenTSDB(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/api/plugins/backendplugin/backendplugin_test.go b/pkg/tests/api/plugins/backendplugin/backendplugin_test.go new file mode 100644 index 00000000000..d80f7743d1e --- /dev/null +++ b/pkg/tests/api/plugins/backendplugin/backendplugin_test.go @@ -0,0 +1,619 @@ +package backendplugin + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/api/dtos" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/plugins/backendplugin" + "github.com/grafana/grafana/pkg/server" + "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/tests/testinfra" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestIntegrationBackendPlugins(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + regularQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { + t.Helper() + + return metricRequestWithQueries(t, fmt.Sprintf(`{ + "datasource": { + "uid": "%s" + } + }`, tsCtx.uid)) + } + + expressionQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest { + t.Helper() + + return metricRequestWithQueries(t, fmt.Sprintf(`{ + "refId": "A", + "datasource": { + "uid": "%s", + "type": "%s" + } + }`, tsCtx.uid, tsCtx.testPluginID), `{ + "refId": "B", + "datasource": { + "type": "__expr__", + "uid": "__expr__", + "name": "Expression" + }, + "type": "math", + "expression": "$A - 50" + }`) + } + + newTestScenario(t, "When oauth token not available", func(t *testing.T, tsCtx *testScenarioContext) { + tsCtx.testEnv.OAuthTokenService.Token = nil + + tsCtx.runCheckHealthTest(t) + tsCtx.runCallResourceTest(t) + + t.Run("regular query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx)) + }) + + t.Run("expression query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx)) + }) + }) + + newTestScenario(t, "When oauth token available", func(t *testing.T, tsCtx *testScenarioContext) { + token := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + RefreshToken: "refresh-token", + Expiry: time.Now().UTC().Add(24 * time.Hour), + } + token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) + tsCtx.testEnv.OAuthTokenService.Token = token + + tsCtx.runCheckHealthTest(t) + tsCtx.runCallResourceTest(t) + + t.Run("regular query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx)) + }) + + t.Run("expression query", func(t *testing.T) { + tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx)) + }) + }) +} + +type testScenarioContext struct { + testPluginID string + uid string + grafanaListeningAddr string + testEnv *server.TestEnv + outgoingServer *httptest.Server + outgoingRequest *http.Request + backendTestPlugin *testPlugin + rt http.RoundTripper +} + +func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx *testScenarioContext)) { + tsCtx := testScenarioContext{ + testPluginID: "test-plugin", + } + + dir, path := testinfra.CreateGrafDir(t, testinfra.GrafanaOpts{ + DisableAnonymous: true, + // EnableLog: true, + }) + + grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) + tsCtx.grafanaListeningAddr = grafanaListeningAddr + tsCtx.testEnv = testEnv + ctx := context.Background() + + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ + DefaultOrgRole: string(org.RoleAdmin), + Password: "admin", + Login: "admin", + }) + + tsCtx.outgoingServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tsCtx.outgoingRequest = r + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(tsCtx.outgoingServer.Close) + + testPlugin, backendTestPlugin := createTestPlugin(tsCtx.testPluginID) + tsCtx.backendTestPlugin = backendTestPlugin + err := testEnv.PluginRegistry.Add(ctx, testPlugin) + require.NoError(t, err) + + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "httpHeaderName1": "X-CUSTOM-HEADER", + "oauthPassThru": true, + "keepCookies": []string{"cookie1", "cookie3", "grafana_session"}, + }) + secureJSONData := map[string]string{ + "basicAuthPassword": "basicAuthPassword", + "httpHeaderValue1": "custom-header-value", + } + + tsCtx.uid = "test-plugin" + err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, &datasources.AddDataSourceCommand{ + OrgId: 1, + Access: datasources.DS_ACCESS_PROXY, + Name: "TestPlugin", + Type: tsCtx.testPluginID, + Uid: tsCtx.uid, + Url: tsCtx.outgoingServer.URL, + BasicAuth: true, + BasicAuthUser: "basicAuthUser", + JsonData: jsonData, + SecureJsonData: secureJSONData, + }) + require.NoError(t, err) + + getDataSourceQuery := &datasources.GetDataSourceQuery{ + OrgId: 1, + Uid: tsCtx.uid, + } + err = testEnv.Server.HTTPServer.DataSourcesService.GetDataSource(ctx, getDataSourceQuery) + require.NoError(t, err) + + rt, err := testEnv.Server.HTTPServer.DataSourcesService.GetHTTPTransport(ctx, getDataSourceQuery.Result, testEnv.HTTPClientProvider) + require.NoError(t, err) + + tsCtx.rt = rt + + t.Run(name, func(t *testing.T) { + callback(t, &tsCtx) + }) +} + +func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest) { + t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) { + var received *struct { + ctx context.Context + req *backend.QueryDataRequest + } + tsCtx.backendTestPlugin.QueryDataHandler = backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + received = &struct { + ctx context.Context + req *backend.QueryDataRequest + }{ctx, req} + + c := http.Client{ + Transport: tsCtx.rt, + } + outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil) + require.NoError(t, err) + resp, err := c.Do(outReq) + if err != nil { + return nil, err + } + defer func() { + if err := resp.Body.Close(); err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err) + } + }() + + _, err = io.Copy(io.Discard, resp.Body) + if err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err) + } + + return &backend.QueryDataResponse{}, nil + }) + + buf1 := &bytes.Buffer{} + err := json.NewEncoder(buf1).Encode(mr) + require.NoError(t, err) + u := fmt.Sprintf("http://admin:admin@%s/api/ds/query", tsCtx.grafanaListeningAddr) + + req, err := http.NewRequest(http.MethodPost, u, buf1) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{ + Name: "cookie1", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie2", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie3", + }) + req.AddCookie(&http.Cookie{ + Name: "grafana_session", + }) + + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, string(b)) + t.Cleanup(func() { + err := resp.Body.Close() + require.NoError(t, err) + }) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // backend query data request + require.NotNil(t, received) + require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"]) + + token := tsCtx.testEnv.OAuthTokenService.Token + + var expectedAuthHeader string + var expectedTokenHeader string + + if token != nil { + expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) + expectedTokenHeader = token.Extra("id_token").(string) + + require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"]) + require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"]) + } + + // outgoing HTTP request + require.NotNil(t, tsCtx.outgoingRequest) + require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) + require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) + + if token == nil { + username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() + require.True(t, ok) + require.Equal(t, "basicAuthUser", username) + require.Equal(t, "basicAuthPassword", pwd) + } else { + require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) + require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) + } + }) +} + +func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) { + t.Run("When calling /api/datasources/uid/:uid/health should set expected headers on outgoing CheckHealth and HTTP request", func(t *testing.T) { + var received *struct { + ctx context.Context + req *backend.CheckHealthRequest + } + tsCtx.backendTestPlugin.CheckHealthHandler = backend.CheckHealthHandlerFunc(func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + received = &struct { + ctx context.Context + req *backend.CheckHealthRequest + }{ctx, req} + + c := http.Client{ + Transport: tsCtx.rt, + } + outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil) + require.NoError(t, err) + resp, err := c.Do(outReq) + if err != nil { + return nil, err + } + defer func() { + if err := resp.Body.Close(); err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err) + } + }() + + _, err = io.Copy(io.Discard, resp.Body) + if err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err) + } + + return &backend.CheckHealthResult{ + Status: backend.HealthStatusOk, + }, nil + }) + + u := fmt.Sprintf("http://admin:admin@%s/api/datasources/uid/%s/health", tsCtx.grafanaListeningAddr, tsCtx.uid) + + req, err := http.NewRequest(http.MethodGet, u, nil) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{ + Name: "cookie1", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie2", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie3", + }) + req.AddCookie(&http.Cookie{ + Name: "grafana_session", + }) + + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, string(b)) + t.Cleanup(func() { + err := resp.Body.Close() + require.NoError(t, err) + }) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // backend query data request + require.NotNil(t, received) + require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"]) + + token := tsCtx.testEnv.OAuthTokenService.Token + + var expectedAuthHeader string + var expectedTokenHeader string + + if token != nil { + expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) + expectedTokenHeader = token.Extra("id_token").(string) + + require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"]) + require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"]) + } + + // outgoing HTTP request + require.NotNil(t, tsCtx.outgoingRequest) + require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) + require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) + + if token == nil { + username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() + require.True(t, ok) + require.Equal(t, "basicAuthUser", username) + require.Equal(t, "basicAuthPassword", pwd) + } else { + require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) + require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) + } + }) +} + +func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) { + t.Run("When calling /api/datasources/uid/:uid/resources should set expected headers on outgoing CallResource and HTTP request", func(t *testing.T) { + var received *struct { + ctx context.Context + req *backend.CallResourceRequest + } + tsCtx.backendTestPlugin.CallResourceHandler = backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + received = &struct { + ctx context.Context + req *backend.CallResourceRequest + }{ctx, req} + + c := http.Client{ + Transport: tsCtx.rt, + } + outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil) + require.NoError(t, err) + resp, err := c.Do(outReq) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err) + } + }() + + _, err = io.Copy(io.Discard, resp.Body) + if err != nil { + tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err) + } + + err = sender.Send(&backend.CallResourceResponse{ + Status: http.StatusOK, + }) + + return err + }) + + u := fmt.Sprintf("http://admin:admin@%s/api/datasources/uid/%s/resources", tsCtx.grafanaListeningAddr, tsCtx.uid) + + req, err := http.NewRequest(http.MethodGet, u, nil) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{ + Name: "cookie1", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie2", + }) + req.AddCookie(&http.Cookie{ + Name: "cookie3", + }) + req.AddCookie(&http.Cookie{ + Name: "grafana_session", + }) + + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, string(b)) + t.Cleanup(func() { + err := resp.Body.Close() + require.NoError(t, err) + }) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // backend query data request + require.NotNil(t, received) + require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"][0]) + + token := tsCtx.testEnv.OAuthTokenService.Token + + var expectedAuthHeader string + var expectedTokenHeader string + + if token != nil { + expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken) + expectedTokenHeader = token.Extra("id_token").(string) + + require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"][0]) + require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"][0]) + } + + // outgoing HTTP request + require.NotNil(t, tsCtx.outgoingRequest) + require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie")) + require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER")) + + if token == nil { + username, pwd, ok := tsCtx.outgoingRequest.BasicAuth() + require.True(t, ok) + require.Equal(t, "basicAuthUser", username) + require.Equal(t, "basicAuthPassword", pwd) + } else { + require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization")) + require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token")) + } + }) +} + +func createTestPlugin(id string) (*plugins.Plugin, *testPlugin) { + p := &plugins.Plugin{ + JSONData: plugins.JSONData{ + ID: id, + }, + Class: plugins.Core, + } + + p.SetLogger(log.New("test-plugin")) + tp := &testPlugin{ + pluginID: id, + logger: p.Logger(), + QueryDataHandler: backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + return &backend.QueryDataResponse{}, nil + }), + } + p.RegisterClient(tp) + + return p, tp +} + +type testPlugin struct { + pluginID string + logger log.Logger + backend.CheckHealthHandler + backend.CallResourceHandler + backend.QueryDataHandler + backend.StreamHandler +} + +func (tp *testPlugin) PluginID() string { + return tp.pluginID +} + +func (tp *testPlugin) Logger() log.Logger { + return tp.logger +} + +func (tp *testPlugin) Start(ctx context.Context) error { + return nil +} + +func (tp *testPlugin) Stop(ctx context.Context) error { + return nil +} + +func (tp *testPlugin) IsManaged() bool { + return true +} + +func (tp *testPlugin) Exited() bool { + return false +} + +func (tp *testPlugin) Decommission() error { + return nil +} + +func (tp *testPlugin) IsDecommissioned() bool { + return false +} + +func (tp *testPlugin) CollectMetrics(_ context.Context, _ *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) { + return nil, backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { + if tp.CheckHealthHandler != nil { + return tp.CheckHealthHandler.CheckHealth(ctx, req) + } + + return nil, backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + if tp.QueryDataHandler != nil { + return tp.QueryDataHandler.QueryData(ctx, req) + } + + return nil, backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + if tp.CallResourceHandler != nil { + return tp.CallResourceHandler.CallResource(ctx, req, sender) + } + + return backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if tp.StreamHandler != nil { + return tp.StreamHandler.SubscribeStream(ctx, req) + } + return nil, backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if tp.StreamHandler != nil { + return tp.StreamHandler.PublishStream(ctx, req) + } + return nil, backendplugin.ErrMethodNotImplemented +} + +func (tp *testPlugin) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if tp.StreamHandler != nil { + return tp.StreamHandler.RunStream(ctx, req, sender) + } + return backendplugin.ErrMethodNotImplemented +} + +func metricRequestWithQueries(t *testing.T, rawQueries ...string) dtos.MetricRequest { + t.Helper() + queries := make([]*simplejson.Json, 0) + for _, q := range rawQueries { + json, err := simplejson.NewJson([]byte(q)) + require.NoError(t, err) + queries = append(queries, json) + } + return dtos.MetricRequest{ + From: "now-1h", + To: "now", + Queries: queries, + Debug: false, + } +} diff --git a/pkg/tests/api/prometheus/prometheus_test.go b/pkg/tests/api/prometheus/prometheus_test.go index 391ad6ec73a..6284659e0ab 100644 --- a/pkg/tests/api/prometheus/prometheus_test.go +++ b/pkg/tests/api/prometheus/prometheus_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/tests/testinfra" "github.com/stretchr/testify/require" @@ -31,7 +30,7 @@ func TestIntegrationPrometheusBuffered(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -116,7 +115,7 @@ func TestIntegrationPrometheusClient(t *testing.T) { grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path) ctx := context.Background() - createUser(t, testEnv.SQLStore, user.CreateUserCommand{ + testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{ DefaultOrgRole: string(org.RoleAdmin), Password: "admin", Login: "admin", @@ -213,14 +212,3 @@ func TestIntegrationPrometheusClient(t *testing.T) { require.Equal(t, "basicAuthPassword", pwd) }) } - -func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { - t.Helper() - - store.Cfg.AutoAssignOrg = true - store.Cfg.AutoAssignOrgId = 1 - - u, err := store.CreateUser(context.Background(), cmd) - require.NoError(t, err) - return u.ID -} diff --git a/pkg/tests/testinfra/testinfra.go b/pkg/tests/testinfra/testinfra.go index fe13679fc59..d5aeef74c52 100644 --- a/pkg/tests/testinfra/testinfra.go +++ b/pkg/tests/testinfra/testinfra.go @@ -23,6 +23,7 @@ import ( "github.com/grafana/grafana/pkg/server" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -307,7 +308,13 @@ func CreateGrafDir(t *testing.T, opts ...GrafanaOpts) (string, string) { require.NoError(t, err) _, err = logSection.NewKey("enabled", "false") require.NoError(t, err) + } else { + serverSection, err := getOrCreateSection("server") + require.NoError(t, err) + _, err = serverSection.NewKey("router_logging", "true") + require.NoError(t, err) } + if o.GRPCServerAddress != "" { logSection, err := getOrCreateSection("grpc_server") require.NoError(t, err) @@ -356,3 +363,14 @@ type GrafanaOpts struct { GRPCServerAddress string QueryRetries int64 } + +func CreateUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 { + t.Helper() + + store.Cfg.AutoAssignOrg = true + store.Cfg.AutoAssignOrgId = 1 + + u, err := store.CreateUser(context.Background(), cmd) + require.NoError(t, err) + return u.ID +} diff --git a/pkg/tsdb/legacydata/service/service.go b/pkg/tsdb/legacydata/service/service.go index f46673b605f..35bdcdd9edd 100644 --- a/pkg/tsdb/legacydata/service/service.go +++ b/pkg/tsdb/legacydata/service/service.go @@ -2,7 +2,6 @@ package service import ( "context" - "fmt" "time" "github.com/grafana/grafana-plugin-sdk-go/backend" @@ -14,10 +13,6 @@ import ( "github.com/grafana/grafana/pkg/tsdb/legacydata" ) -var oAuthIsOAuthPassThruEnabledFunc = func(oAuthTokenService oauthtoken.OAuthTokenService, ds *datasources.DataSource) bool { - return oAuthTokenService.IsOAuthPassThruEnabled(ds) -} - type Service struct { pluginsClient plugins.Client oAuthTokenService oauthtoken.OAuthTokenService @@ -45,13 +40,6 @@ func (h *Service) HandleRequest(ctx context.Context, ds *datasources.DataSource, return legacydata.DataResponse{}, err } - // Attach Auth information - if oAuthIsOAuthPassThruEnabledFunc(h.oAuthTokenService, ds) { - if token := h.oAuthTokenService.GetCurrentOAuthToken(ctx, query.User); token != nil { - query.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) - } - } - resp, err := h.pluginsClient.QueryData(ctx, req) if err != nil { return legacydata.DataResponse{}, err diff --git a/pkg/tsdb/legacydata/service/service_test.go b/pkg/tsdb/legacydata/service/service_test.go index beee540e81b..493a204ab9a 100644 --- a/pkg/tsdb/legacydata/service/service_test.go +++ b/pkg/tsdb/legacydata/service/service_test.go @@ -15,7 +15,6 @@ import ( "github.com/grafana/grafana/pkg/services/datasources" datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service" "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/quota/quotatest" "github.com/grafana/grafana/pkg/services/secrets/fakes" secretskvs "github.com/grafana/grafana/pkg/services/secrets/kvstore" @@ -25,15 +24,6 @@ import ( func TestHandleRequest(t *testing.T) { t.Run("Should invoke plugin manager QueryData when handling request for query", func(t *testing.T) { - origOAuthIsOAuthPassThruEnabledFunc := oAuthIsOAuthPassThruEnabledFunc - oAuthIsOAuthPassThruEnabledFunc = func(oAuthTokenService oauthtoken.OAuthTokenService, ds *datasources.DataSource) bool { - return false - } - - t.Cleanup(func() { - oAuthIsOAuthPassThruEnabledFunc = origOAuthIsOAuthPassThruEnabledFunc - }) - client := &fakePluginsClient{} var actualReq *backend.QueryDataRequest client.QueryDataHandlerFunc = func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {