Plugins: Refactor forward of cookies, OAuth token and header modifications by introducing client middlewares (#58132)

Adding support for backend plugin client middlewares. This allows headers in outgoing 
backend plugin and HTTP requests to be modified using client middlewares.

The following client middlewares added:
Forward cookies: Will forward incoming HTTP request Cookies to outgoing plugins.Client 
and HTTP requests if the datasource has enabled forwarding of cookies (keepCookies).
Forward OAuth token: Will set OAuth token headers on outgoing plugins.Client and HTTP 
requests if the datasource has enabled Forward OAuth Identity (oauthPassThru).
Clear auth headers: Will clear any outgoing HTTP headers that was part of the incoming 
HTTP request and used when authenticating to Grafana.
The current suggested way to register client middlewares is to have a separate package, 
pluginsintegration, responsible for bootstrap/instantiate the backend plugin client with 
middlewares and/or longer term bootstrap/instantiate plugin management. 

Fixes #54135
Related to #47734
Related to #57870
Related to #41623
Related to #57065
This commit is contained in:
Marcus Efraimsson
2022-12-01 10:08:36 -08:00
committed by GitHub
parent 632ca67e3f
commit 6dbe3b555f
56 changed files with 2923 additions and 1000 deletions

3
.github/CODEOWNERS vendored
View File

@@ -103,9 +103,11 @@ go.sum @grafana/backend-platform
# Plugins
/pkg/api/pluginproxy @grafana/plugins-platform-backend
/pkg/infra/httpclient @grafana/plugins-platform-backend
/pkg/plugins @grafana/plugins-platform-backend
/pkg/services/datasourceproxy @grafana/plugins-platform-backend
/pkg/services/datasources @grafana/plugins-platform-backend
/pkg/services/pluginsintegration @grafana/plugins-platform-backend
/pkg/plugins/pfs @grafana/plugins-platform-backend @grafana/grafana-as-code
# Dashboard previews / crawler (behind feature flag)
@@ -229,6 +231,7 @@ lerna.json @grafana/frontend-ops
# Grafana Partnerships Team
/pkg/infra/httpclient/httpclientprovider/sigv4_middleware.go @grafana/grafana-partnerships-team
/pkg/infra/httpclient/httpclientprovider/sigv4_middleware_test.go @grafana/grafana-partnerships-team
# Kind system and code generation
embed.go @grafana/grafana-as-code

View File

@@ -24,7 +24,6 @@ import (
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/util/proxyutil"
"github.com/grafana/grafana/pkg/web"
)
@@ -823,21 +822,6 @@ func (hs *HTTPServer) checkDatasourceHealth(c *models.ReqContext, ds *datasource
return response.Error(http.StatusForbidden, "Access denied", err)
}
if hs.DataProxy.OAuthTokenService.IsOAuthPassThruEnabled(ds) {
if token := hs.DataProxy.OAuthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser); token != nil {
req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Headers["X-ID-Token"] = idToken
}
}
}
proxyutil.ClearCookieHeader(c.Req, ds.AllowedCookies(), []string{hs.Cfg.LoginCookieName})
if cookieStr := c.Req.Header.Get("Cookie"); cookieStr != "" {
req.Headers["Cookie"] = cookieStr
}
resp, err := hs.pluginClient.CheckHealth(c.Req.Context(), req)
if err != nil {
return translatePluginRequestErrorToAPIError(err)

View File

@@ -3,7 +3,6 @@ package dtos
import (
"crypto/md5"
"fmt"
"net/http"
"regexp"
"strings"
@@ -73,8 +72,6 @@ type MetricRequest struct {
Debug bool `json:"debug"`
PublicDashboardAccessToken string `json:"publicDashboardAccessToken"`
HTTPRequest *http.Request `json:"-"`
}
func (mr *MetricRequest) GetUniqueDatasourceTypes() []string {
@@ -98,11 +95,10 @@ func (mr *MetricRequest) GetUniqueDatasourceTypes() []string {
func (mr *MetricRequest) CloneWithQueries(queries []*simplejson.Json) MetricRequest {
return MetricRequest{
From: mr.From,
To: mr.To,
Queries: queries,
Debug: mr.Debug,
HTTPRequest: mr.HTTPRequest,
From: mr.From,
To: mr.To,
Queries: queries,
Debug: mr.Debug,
}
}

View File

@@ -52,8 +52,6 @@ func (hs *HTTPServer) QueryMetricsV2(c *models.ReqContext) response.Response {
return response.Error(http.StatusBadRequest, "bad request data", err)
}
reqDTO.HTTPRequest = c.Req
resp, err := hs.queryDataService.QueryData(c.Req.Context(), c.SignedInUser, c.SkipCache, reqDTO)
if err != nil {
return hs.handleQueryMetricsError(err)

View File

@@ -13,15 +13,12 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/plugins/config"
pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
"github.com/grafana/grafana/pkg/services/datasources"
fakeDatasources "github.com/grafana/grafana/pkg/services/datasources/fakes"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org"
@@ -46,31 +43,6 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request)
return rv.err
}
type fakeOAuthTokenService struct {
passThruEnabled bool
token *oauth2.Token
}
func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
return ts.token
}
func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return nil
}
// `/ds/query` endpoint test
func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
qds := query.ProvideService(
@@ -89,7 +61,6 @@ func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
return &backend.QueryDataResponse{Responses: resp}, nil
},
},
&fakeOAuthTokenService{},
)
serverFeatureEnabled := SetupAPITestServer(t, func(hs *HTTPServer) {
hs.queryDataService = qds
@@ -138,7 +109,6 @@ func TestAPIEndpoint_Metrics_PluginDecryptionFailure(t *testing.T) {
return &backend.QueryDataResponse{Responses: resp}, nil
},
},
&fakeOAuthTokenService{},
)
httpServer := SetupAPITestServer(t, func(hs *HTTPServer) {
hs.queryDataService = qds
@@ -292,7 +262,6 @@ func TestDataSourceQueryError(t *testing.T) {
&fakePluginRequestValidator{},
&fakeDatasources.FakeDataSourceService{},
pluginClient.ProvideService(r, &config.Cfg{}),
&fakeOAuthTokenService{},
)
hs.QuotaService = quotatest.New(false, nil)
})

View File

@@ -2,7 +2,6 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -15,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/util/proxyutil"
"github.com/grafana/grafana/pkg/web"
@@ -78,17 +76,6 @@ func (hs *HTTPServer) callPluginResourceWithDataSource(c *models.ReqContext, plu
return
}
if hs.DataProxy.OAuthTokenService.IsOAuthPassThruEnabled(ds) {
if token := hs.DataProxy.OAuthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser); token != nil {
req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Header.Add("X-ID-Token", idToken)
}
}
}
if err = hs.makePluginResourceRequest(c.Resp, req, pCtx); err != nil {
handleCallResourceError(err, c)
}
@@ -110,24 +97,6 @@ func (hs *HTTPServer) pluginResourceRequest(c *models.ReqContext) (*http.Request
}
func (hs *HTTPServer) makePluginResourceRequest(w http.ResponseWriter, req *http.Request, pCtx backend.PluginContext) error {
keepCookieModel := struct {
KeepCookies []string `json:"keepCookies"`
}{}
if dis := pCtx.DataSourceInstanceSettings; dis != nil {
err := json.Unmarshal(dis.JSONData, &keepCookieModel)
if err != nil {
hs.log.Warn("failed to unpack JSONData in datasource instance settings", "err", err)
}
}
list := contexthandler.AuthHTTPHeaderListFromContext(req.Context())
if list != nil {
for _, name := range list.Items {
req.Header.Del(name)
}
}
proxyutil.ClearCookieHeader(req, keepCookieModel.KeepCookies, []string{hs.Cfg.LoginCookieName})
proxyutil.PrepareProxyRequest(req)
body, err := io.ReadAll(req.Body)

View File

@@ -23,7 +23,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/pluginsettings"
"github.com/grafana/grafana/pkg/services/quota/quotatest"
@@ -315,11 +314,6 @@ func TestMakePluginResourceRequest(t *testing.T) {
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
const customHeader = "X-CUSTOM"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)
resp := httptest.NewRecorder()
pCtx := backend.PluginContext{}
err := hs.makePluginResourceRequest(resp, req, pCtx)
@@ -333,7 +327,6 @@ func TestMakePluginResourceRequest(t *testing.T) {
require.Equal(t, resp.Header().Get("Content-Type"), "application/json")
require.Equal(t, "sandbox", resp.Header().Get("Content-Security-Policy"))
require.Empty(t, req.Header.Get(customHeader))
}
func TestMakePluginResourceRequestSetCookieNotPresent(t *testing.T) {

View File

@@ -34,18 +34,7 @@ import (
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
pluginsCfg "github.com/grafana/grafana/pkg/plugins/config"
"github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/plugins/manager/client"
pluginDashboards "github.com/grafana/grafana/pkg/plugins/manager/dashboards"
"github.com/grafana/grafana/pkg/plugins/manager/loader"
processManager "github.com/grafana/grafana/pkg/plugins/manager/process"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
managerStore "github.com/grafana/grafana/pkg/plugins/manager/store"
"github.com/grafana/grafana/pkg/plugins/plugincontext"
"github.com/grafana/grafana/pkg/plugins/repo"
"github.com/grafana/grafana/pkg/registry/corekind"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
@@ -94,6 +83,7 @@ import (
plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service"
"github.com/grafana/grafana/pkg/services/pluginsettings"
pluginSettings "github.com/grafana/grafana/pkg/services/pluginsettings/service"
"github.com/grafana/grafana/pkg/services/pluginsintegration"
"github.com/grafana/grafana/pkg/services/preference/prefimpl"
"github.com/grafana/grafana/pkg/services/publicdashboards"
publicdashboardsApi "github.com/grafana/grafana/pkg/services/publicdashboards/api"
@@ -181,28 +171,9 @@ var wireSet = wire.NewSet(
updatechecker.ProvideGrafanaService,
updatechecker.ProvidePluginsService,
uss.ProvideService,
pluginsCfg.ProvideConfig,
registry.ProvideService,
wire.Bind(new(registry.Service), new(*registry.InMemory)),
repo.ProvideService,
wire.Bind(new(repo.Service), new(*repo.Manager)),
manager.ProvideInstaller,
wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)),
client.ProvideService,
wire.Bind(new(plugins.Client), new(*client.Service)),
managerStore.ProvideService,
wire.Bind(new(plugins.Store), new(*managerStore.Service)),
wire.Bind(new(plugins.RendererManager), new(*managerStore.Service)),
wire.Bind(new(plugins.SecretsPluginManager), new(*managerStore.Service)),
wire.Bind(new(plugins.StaticRouteResolver), new(*managerStore.Service)),
pluginsintegration.WireSet,
pluginDashboards.ProvideFileStoreManager,
wire.Bind(new(pluginDashboards.FileStore), new(*pluginDashboards.FileStoreManager)),
processManager.ProvideService,
wire.Bind(new(processManager.Service), new(*processManager.Manager)),
coreplugin.ProvideCoreRegistry,
loader.ProvideService,
wire.Bind(new(loader.Service), new(*loader.Loader)),
wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)),
cloudwatch.ProvideService,
cloudmonitoring.ProvideService,
azuremonitor.ProvideService,
@@ -236,7 +207,6 @@ var wireSet = wire.NewSet(
export.ProvideService,
live.ProvideService,
pushhttp.ProvideService,
plugincontext.ProvideService,
contexthandler.ProvideService,
jwt.ProvideService,
wire.Bind(new(models.JWTService), new(*jwt.AuthService)),

View File

@@ -7,9 +7,6 @@ import (
"github.com/google/wire"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/provider"
"github.com/grafana/grafana/pkg/plugins/manager/signature"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/server/backgroundsvcs"
"github.com/grafana/grafana/pkg/server/usagestatssvcs"
@@ -29,6 +26,7 @@ import (
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
"github.com/grafana/grafana/pkg/services/pluginsintegration"
"github.com/grafana/grafana/pkg/services/provisioning"
"github.com/grafana/grafana/pkg/services/searchusers"
"github.com/grafana/grafana/pkg/services/searchusers/filters"
@@ -71,10 +69,6 @@ var wireExtsSet = wire.NewSet(
wire.Bind(new(user.SearchUserFilter), new(*filters.OSSSearchUserFilter)),
searchusers.ProvideUsersService,
wire.Bind(new(searchusers.Service), new(*searchusers.OSSService)),
signature.ProvideOSSAuthorizer,
wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)),
provider.ProvideService,
wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)),
ldap.ProvideGroupsService,
wire.Bind(new(ldap.Groups), new(*ldap.OSSGroups)),
permissions.ProvideDatasourcePermissionsService,
@@ -85,4 +79,5 @@ var wireExtsSet = wire.NewSet(
wire.Bind(new(accesscontrol.DatasourcePermissionsService), new(*ossaccesscontrol.DatasourcePermissionsService)),
encryptionprovider.ProvideEncryptionProvider,
wire.Bind(new(encryption.Provider), new(encryptionprovider.Provider)),
pluginsintegration.WireExtensionSet,
)

View File

@@ -63,6 +63,10 @@ func (dp *DataPipeline) execute(c context.Context, now time.Time, s *Service) (m
// BuildPipeline builds a graph of the nodes, and returns the nodes in an
// executable order.
func (s *Service) buildPipeline(req *Request) (DataPipeline, error) {
if req != nil && len(req.Headers) == 0 {
req.Headers = map[string]string{}
}
graph, err := s.buildDependencyGraph(req)
if err != nil {
return nil, err
@@ -144,12 +148,11 @@ func (s *Service) buildGraph(req *Request) (*simple.DirectedGraph, error) {
}
rn := &rawNode{
Query: rawQueryProp,
RefID: query.RefID,
TimeRange: query.TimeRange,
QueryType: query.QueryType,
DataSource: query.DataSource,
QueryEnricher: query.QueryEnricher,
Query: rawQueryProp,
RefID: query.RefID,
TimeRange: query.TimeRange,
QueryType: query.QueryType,
DataSource: query.DataSource,
}
var node Node

View File

@@ -42,12 +42,11 @@ type baseNode struct {
}
type rawNode struct {
RefID string `json:"refId"`
Query map[string]interface{}
QueryType string
TimeRange TimeRange
DataSource *datasources.DataSource
QueryEnricher QueryDataRequestEnricher
RefID string `json:"refId"`
Query map[string]interface{}
QueryType string
TimeRange TimeRange
DataSource *datasources.DataSource
}
func (rn *rawNode) GetCommandType() (c CommandType, err error) {
@@ -140,9 +139,8 @@ const (
// DSNode is a DPNode that holds a datasource request.
type DSNode struct {
baseNode
query json.RawMessage
datasource *datasources.DataSource
queryEnricher QueryDataRequestEnricher
query json.RawMessage
datasource *datasources.DataSource
orgID int64
queryType string
@@ -171,15 +169,14 @@ func (s *Service) buildDSNode(dp *simple.DirectedGraph, rn *rawNode, req *Reques
id: dp.NewNode().ID(),
refID: rn.RefID,
},
orgID: req.OrgId,
query: json.RawMessage(encodedQuery),
queryType: rn.QueryType,
intervalMS: defaultIntervalMS,
maxDP: defaultMaxDP,
timeRange: rn.TimeRange,
request: *req,
datasource: rn.DataSource,
queryEnricher: rn.QueryEnricher,
orgID: req.OrgId,
query: json.RawMessage(encodedQuery),
queryType: rn.QueryType,
intervalMS: defaultIntervalMS,
maxDP: defaultMaxDP,
timeRange: rn.TimeRange,
request: *req,
datasource: rn.DataSource,
}
var floatIntervalMS float64
@@ -232,10 +229,6 @@ func (dn *DSNode) Execute(ctx context.Context, now time.Time, _ mathexp.Vars, s
Headers: dn.request.Headers,
}
if dn.queryEnricher != nil {
ctx = dn.queryEnricher(ctx, req)
}
resp, err := s.dataService.QueryData(ctx, req)
if err != nil {
return mathexp.Results{}, err

View File

@@ -38,16 +38,12 @@ type Request struct {
User *backend.User
}
// QueryDataRequestEnricher function definition for enriching a backend.QueryDataRequest request.
type QueryDataRequestEnricher func(ctx context.Context, req *backend.QueryDataRequest) context.Context
// Query is like plugins.DataSubQuery, but with a a time range, and only the UID
// for the data source. Also interval is a time.Duration.
type Query struct {
RefID string
TimeRange TimeRange
DataSource *datasources.DataSource `json:"datasource"`
QueryEnricher QueryDataRequestEnricher
JSON json.RawMessage
Interval time.Duration
QueryType string

View File

@@ -0,0 +1,27 @@
package httpclientprovider
import (
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
const DeleteHeadersMiddlewareName = "delete-headers"
// DeleteHeadersMiddleware middleware that delete headers on the outgoing
// request if header names provided.
func DeleteHeadersMiddleware(headerNames ...string) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(DeleteHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if len(headerNames) == 0 {
return next
}
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
for _, k := range headerNames {
req.Header.Del(k)
}
return next.RoundTrip(req)
})
})
}

View File

@@ -0,0 +1,66 @@
package httpclientprovider
import (
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/require"
)
func TestDeleteHeadersMiddleware(t *testing.T) {
t.Run("Without headerNames should return next http.RoundTripper", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("finalrt")
var headerNames []string
mw := DeleteHeadersMiddleware(headerNames...)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
})
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("final")
headerNames := []string{"X-Header-B", "X-Header-C"}
mw := DeleteHeadersMiddleware(headerNames...)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, DeleteHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
req.Header.Set("X-Header-A", "a")
req.Header.Set("X-Header-B", "b")
req.Header.Set("X-Header-C", "c")
req.Header.Set("X-Header-D", "d")
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
require.Equal(t, "a", req.Header.Get("X-Header-A"))
require.Empty(t, req.Header.Get("X-Header-B"))
require.Empty(t, req.Header.Get("X-Header-C"))
require.Equal(t, "d", req.Header.Get("X-Header-D"))
})
}

View File

@@ -70,3 +70,16 @@ func TestForwardedCookiesMiddleware(t *testing.T) {
})
}
}
type testContext struct {
callChain []string
req *http.Request
}
func (c *testContext) createRoundTripper() http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
c.callChain = append(c.callChain, "final")
c.req = req
return &http.Response{StatusCode: http.StatusOK}, nil
})
}

View File

@@ -1,31 +0,0 @@
package httpclientprovider
import (
"fmt"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"golang.org/x/oauth2"
)
const ForwardedOAuthIdentityMiddlewareName = "forwarded-oauth-identity"
// ForwardedOAuthIdentityMiddleware middleware that sets Authorization/X-ID-Token
// headers on the outgoing request if an OAuth Token is provided
func ForwardedOAuthIdentityMiddleware(token *oauth2.Token) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(ForwardedOAuthIdentityMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if token == nil {
return next
}
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Header.Set("X-ID-Token", idToken)
}
return next.RoundTrip(req)
})
})
}

View File

@@ -1,82 +0,0 @@
package httpclientprovider_test
import (
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestForwardedOAuthIdentityMiddleware(t *testing.T) {
at := &oauth2.Token{
AccessToken: "access-token",
}
tcs := []struct {
desc string
token *oauth2.Token
expectedAuthorizationHeader string
expectedIDTokenHeader string
}{
{
desc: "With nil token should not populate Cookie headers",
token: nil,
expectedAuthorizationHeader: "",
expectedIDTokenHeader: "",
},
{
desc: "With access token set should populate Authorization header",
token: at,
expectedAuthorizationHeader: "Bearer access-token",
expectedIDTokenHeader: "",
},
{
desc: "With Authorization and X-ID-Token header set should populate Authorization and X-Id-Token header",
token: at.WithExtra(map[string]interface{}{"id_token": "id-token"}),
expectedAuthorizationHeader: "Bearer access-token",
expectedIDTokenHeader: "id-token",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper()
mw := httpclientprovider.ForwardedOAuthIdentityMiddleware(tc.token)
opts := httpclient.Options{}
rt := mw.CreateMiddleware(opts, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, "forwarded-oauth-identity", middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
require.Equal(t, tc.expectedAuthorizationHeader, ctx.req.Header.Get("Authorization"))
require.Equal(t, tc.expectedIDTokenHeader, ctx.req.Header.Get("X-ID-Token"))
})
}
}
type testContext struct {
callChain []string
req *http.Request
}
func (c *testContext) createRoundTripper() http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
c.callChain = append(c.callChain, "final")
c.req = req
return &http.Response{StatusCode: http.StatusOK}, nil
})
}

View File

@@ -25,10 +25,10 @@ func New(cfg *setting.Cfg, validator models.PluginRequestValidator, tracer traci
middlewares := []sdkhttpclient.Middleware{
TracingMiddleware(logger, tracer),
DataSourceMetricsMiddleware(),
sdkhttpclient.ContextualMiddleware(),
SetUserAgentMiddleware(userAgent),
sdkhttpclient.BasicAuthenticationMiddleware(),
sdkhttpclient.CustomHeadersMiddleware(),
sdkhttpclient.ContextualMiddleware(),
ResponseLimitMiddleware(cfg.ResponseLimit),
RedirectLimitMiddleware(validator),
}

View File

@@ -29,10 +29,10 @@ func TestHTTPClientProvider(t *testing.T) {
require.Len(t, o.Middlewares, 8)
require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
})
@@ -53,10 +53,10 @@ func TestHTTPClientProvider(t *testing.T) {
require.Len(t, o.Middlewares, 9)
require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SigV4MiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName())
})
@@ -78,10 +78,10 @@ func TestHTTPClientProvider(t *testing.T) {
require.Len(t, o.Middlewares, 9)
require.Equal(t, TracingMiddlewareName, o.Middlewares[0].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, DataSourceMetricsMiddlewareName, o.Middlewares[1].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ContextualMiddlewareName, o.Middlewares[2].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, HostRedirectValidationMiddlewareName, o.Middlewares[7].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, HTTPLoggerMiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName())

View File

@@ -0,0 +1,31 @@
package httpclientprovider
import (
"net/http"
"net/textproto"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
const SetHeadersMiddlewareName = "set-headers"
// SetHeadersMiddleware middleware that sets headers on the outgoing
// request if headers provided.
// If the request already contains any of the headers provided, they
// will be overwritten.
func SetHeadersMiddleware(headers http.Header) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(SetHeadersMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if len(headers) == 0 {
return next
}
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
for k, v := range headers {
canonicalKey := textproto.CanonicalMIMEHeaderKey(k)
req.Header[canonicalKey] = v
}
return next.RoundTrip(req)
})
})
}

View File

@@ -0,0 +1,66 @@
package httpclientprovider
import (
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/require"
)
func TestSetHeadersMiddleware(t *testing.T) {
t.Run("Without headers set should return next http.RoundTripper", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("finalrt")
var headers http.Header
mw := SetHeadersMiddleware(headers)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
})
t.Run("With headers set should apply HTTP headers to the request", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("final")
headers := http.Header{
"X-Header-A": []string{"value a"},
"X-Header-B": []string{"value b"},
"X-HEader-C": []string{"value c"},
}
mw := SetHeadersMiddleware(headers)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, SetHeadersMiddlewareName, middlewareName.MiddlewareName())
req, err := http.NewRequest(http.MethodGet, "http://", nil)
require.NoError(t, err)
req.Header.Set("X-Header-B", "d")
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
if res.Body != nil {
require.NoError(t, res.Body.Close())
}
require.Len(t, ctx.callChain, 1)
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
require.Equal(t, "value a", req.Header.Get("X-Header-A"))
require.Equal(t, "value b", req.Header.Get("X-Header-B"))
require.Equal(t, "value c", req.Header.Get("X-Header-C"))
})
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestCustomHeadersMiddleware(t *testing.T) {
func TestSetUserAgentMiddleware(t *testing.T) {
t.Run("Without user agent set should return next http.RoundTripper", func(t *testing.T) {
ctx := &testContext{}
finalRoundTripper := ctx.createRoundTripper("finalrt")

View File

@@ -79,3 +79,20 @@ type PluginLoaderAuthorizer interface {
type RoleRegistry interface {
DeclarePluginRoles(ctx context.Context, ID, name string, registrations []RoleRegistration) error
}
// ClientMiddleware is an interface representing the ability to create a middleware
// that implements the Client interface.
type ClientMiddleware interface {
// CreateClientMiddleware creates a new client middleware.
CreateClientMiddleware(next Client) Client
}
// The ClientMiddlewareFunc type is an adapter to allow the use of ordinary
// functions as ClientMiddleware's. If f is a function with the appropriate
// signature, ClientMiddlewareFunc(f) is a ClientMiddleware that calls f.
type ClientMiddlewareFunc func(next Client) Client
// CreateClientMiddleware implements the ClientMiddleware interface.
func (fn ClientMiddlewareFunc) CreateClientMiddleware(next Client) Client {
return fn(next)
}

View File

@@ -29,6 +29,10 @@ func ProvideService(pluginRegistry registry.Service, cfg *config.Cfg) *Service {
}
func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
plugin, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return nil, plugins.ErrPluginNotRegistered.Errorf("%w", backendplugin.ErrPluginNotRegistered)
@@ -65,6 +69,14 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
}
func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return fmt.Errorf("req cannot be nil")
}
if sender == nil {
return fmt.Errorf("sender cannot be nil")
}
p, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return backendplugin.ErrPluginNotRegistered
@@ -84,6 +96,10 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq
}
func (s *Service) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
p, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return nil, backendplugin.ErrPluginNotRegistered
@@ -102,6 +118,10 @@ func (s *Service) CollectMetrics(ctx context.Context, req *backend.CollectMetric
}
func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
p, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return nil, backendplugin.ErrPluginNotRegistered
@@ -129,6 +149,10 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque
}
func (s *Service) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
plugin, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return nil, backendplugin.ErrPluginNotRegistered
@@ -138,6 +162,10 @@ func (s *Service) SubscribeStream(ctx context.Context, req *backend.SubscribeStr
}
func (s *Service) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
plugin, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return nil, backendplugin.ErrPluginNotRegistered
@@ -147,6 +175,14 @@ func (s *Service) PublishStream(ctx context.Context, req *backend.PublishStreamR
}
func (s *Service) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
if req == nil {
return fmt.Errorf("req cannot be nil")
}
if sender == nil {
return fmt.Errorf("sender cannot be nil")
}
plugin, exists := s.plugin(ctx, req.PluginContext.PluginID)
if !exists {
return backendplugin.ErrPluginNotRegistered

View File

@@ -0,0 +1,205 @@
package clienttest
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/manager/client"
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/require"
)
type TestClient struct {
plugins.Client
QueryDataFunc backend.QueryDataHandlerFunc
CallResourceFunc backend.CallResourceHandlerFunc
CheckHealthFunc backend.CheckHealthHandlerFunc
}
func (c *TestClient) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if c.QueryDataFunc != nil {
return c.QueryDataFunc(ctx, req)
}
return nil, nil
}
func (c *TestClient) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if c.CallResourceFunc != nil {
return c.CallResourceFunc(ctx, req, sender)
}
return nil
}
func (c *TestClient) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if c.CheckHealthFunc != nil {
return c.CheckHealthFunc(ctx, req)
}
return nil, nil
}
type MiddlewareScenarioContext struct {
QueryDataCallChain []string
CallResourceCallChain []string
CollectMetricsCallChain []string
CheckHealthCallChain []string
SubscribeStreamCallChain []string
PublishStreamCallChain []string
RunStreamCallChain []string
}
func (ctx *MiddlewareScenarioContext) NewMiddleware(name string) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &TestMiddleware{
next: next,
Name: name,
sCtx: ctx,
}
})
}
type TestMiddleware struct {
next plugins.Client
sCtx *MiddlewareScenarioContext
Name string
}
func (m *TestMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.QueryData(ctx, req)
m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("before %s", m.Name))
err := m.next.CallResource(ctx, req, sender)
m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("after %s", m.Name))
return err
}
func (m *TestMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.CollectMetrics(ctx, req)
m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.CheckHealth(ctx, req)
m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.SubscribeStream(ctx, req)
m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.PublishStream(ctx, req)
m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("before %s", m.Name))
err := m.next.RunStream(ctx, req, sender)
m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("after %s", m.Name))
return err
}
var _ plugins.Client = &TestClient{}
type ClientDecoratorTest struct {
T *testing.T
Context context.Context
TestClient *TestClient
Middlewares []plugins.ClientMiddleware
Decorator *client.Decorator
ReqContext *models.ReqContext
QueryDataReq *backend.QueryDataRequest
QueryDataCtx context.Context
CallResourceReq *backend.CallResourceRequest
CallResourceCtx context.Context
CheckHealthReq *backend.CheckHealthRequest
CheckHealthCtx context.Context
}
type ClientDecoratorTestOption func(*ClientDecoratorTest)
func NewClientDecoratorTest(t *testing.T, opts ...ClientDecoratorTestOption) *ClientDecoratorTest {
cdt := &ClientDecoratorTest{
T: t,
Context: context.Background(),
}
cdt.TestClient = &TestClient{
QueryDataFunc: func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
cdt.QueryDataReq = req
cdt.QueryDataCtx = ctx
return nil, nil
},
CallResourceFunc: func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
cdt.CallResourceReq = req
cdt.CallResourceCtx = ctx
return nil
},
CheckHealthFunc: func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
cdt.CheckHealthReq = req
cdt.CheckHealthCtx = ctx
return nil, nil
},
}
require.NotNil(t, cdt)
for _, opt := range opts {
opt(cdt)
}
d, err := client.NewDecorator(cdt.TestClient, cdt.Middlewares...)
require.NoError(t, err)
require.NotNil(t, d)
cdt.Decorator = d
return cdt
}
func WithReqContext(req *http.Request, user *user.SignedInUser) ClientDecoratorTestOption {
return ClientDecoratorTestOption(func(cdt *ClientDecoratorTest) {
if cdt.ReqContext == nil {
cdt.ReqContext = &models.ReqContext{
Context: &web.Context{},
SignedInUser: user,
}
}
cdt.Context = ctxkey.Set(cdt.Context, cdt.ReqContext)
*req = *req.WithContext(cdt.Context)
cdt.ReqContext.Req = req
})
}
func WithMiddlewares(middlewares ...plugins.ClientMiddleware) ClientDecoratorTestOption {
return ClientDecoratorTestOption(func(cdt *ClientDecoratorTest) {
if cdt.Middlewares == nil {
cdt.Middlewares = []plugins.ClientMiddleware{}
}
cdt.Middlewares = append(cdt.Middlewares, middlewares...)
})
}

View File

@@ -0,0 +1,126 @@
package client
import (
"context"
"fmt"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/plugins"
)
// Decorator allows a plugins.Client to be decorated with middlewares.
type Decorator struct {
client plugins.Client
middlewares []plugins.ClientMiddleware
}
// NewDecorator creates a new plugins.client decorator.
func NewDecorator(client plugins.Client, middlewares ...plugins.ClientMiddleware) (*Decorator, error) {
if client == nil {
return nil, fmt.Errorf("client cannot be nil")
}
return &Decorator{
client: client,
middlewares: middlewares,
}, nil
}
func (d *Decorator) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.QueryData(ctx, req)
}
func (d *Decorator) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return fmt.Errorf("req cannot be nil")
}
if sender == nil {
return fmt.Errorf("sender cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.CallResource(ctx, req, sender)
}
func (d *Decorator) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.CollectMetrics(ctx, req)
}
func (d *Decorator) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.CheckHealth(ctx, req)
}
func (d *Decorator) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.SubscribeStream(ctx, req)
}
func (d *Decorator) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
if req == nil {
return nil, fmt.Errorf("req cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.PublishStream(ctx, req)
}
func (d *Decorator) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
if req == nil {
return fmt.Errorf("req cannot be nil")
}
if sender == nil {
return fmt.Errorf("sender cannot be nil")
}
client := clientFromMiddlewares(d.middlewares, d.client)
return client.RunStream(ctx, req, sender)
}
func clientFromMiddlewares(middlewares []plugins.ClientMiddleware, finalClient plugins.Client) plugins.Client {
if len(middlewares) == 0 {
return finalClient
}
reversed := reverseMiddlewares(middlewares)
next := finalClient
for _, m := range reversed {
next = m.CreateClientMiddleware(next)
}
return next
}
func reverseMiddlewares(middlewares []plugins.ClientMiddleware) []plugins.ClientMiddleware {
reversed := make([]plugins.ClientMiddleware, len(middlewares))
copy(reversed, middlewares)
for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
reversed[i], reversed[j] = reversed[j], reversed[i]
}
return reversed
}
var _ plugins.Client = &Decorator{}

View File

@@ -0,0 +1,229 @@
package client
import (
"context"
"fmt"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/plugins"
"github.com/stretchr/testify/require"
)
func TestDecorator(t *testing.T) {
var queryDataCalled bool
var callResourceCalled bool
var checkHealthCalled bool
c := &TestClient{
QueryDataFunc: func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
queryDataCalled = true
return nil, nil
},
CallResourceFunc: func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
callResourceCalled = true
return nil
},
CheckHealthFunc: func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
checkHealthCalled = true
return nil, nil
},
}
require.NotNil(t, c)
ctx := MiddlewareScenarioContext{}
mwOne := ctx.NewMiddleware("mw1")
mwTwo := ctx.NewMiddleware("mw2")
d, err := NewDecorator(c, mwOne, mwTwo)
require.NoError(t, err)
require.NotNil(t, d)
_, _ = d.QueryData(context.Background(), &backend.QueryDataRequest{})
require.True(t, queryDataCalled)
sender := callResourceResponseSenderFunc(func(res *backend.CallResourceResponse) error {
return nil
})
_ = d.CallResource(context.Background(), &backend.CallResourceRequest{}, sender)
require.True(t, callResourceCalled)
_, _ = d.CheckHealth(context.Background(), &backend.CheckHealthRequest{})
require.True(t, checkHealthCalled)
require.Len(t, ctx.QueryDataCallChain, 4)
require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.QueryDataCallChain)
require.Len(t, ctx.CallResourceCallChain, 4)
require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.CallResourceCallChain)
require.Len(t, ctx.CheckHealthCallChain, 4)
require.EqualValues(t, []string{"before mw1", "before mw2", "after mw2", "after mw1"}, ctx.CheckHealthCallChain)
}
func TestReverseMiddlewares(t *testing.T) {
t.Run("Should reverse 1 middleware", func(t *testing.T) {
ctx := MiddlewareScenarioContext{}
middlewares := []plugins.ClientMiddleware{
ctx.NewMiddleware("mw1"),
}
reversed := reverseMiddlewares(middlewares)
require.Len(t, reversed, 1)
require.Equal(t, "mw1", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name)
})
t.Run("Should reverse 2 middlewares", func(t *testing.T) {
ctx := MiddlewareScenarioContext{}
middlewares := []plugins.ClientMiddleware{
ctx.NewMiddleware("mw1"),
ctx.NewMiddleware("mw2"),
}
reversed := reverseMiddlewares(middlewares)
require.Len(t, reversed, 2)
require.Equal(t, "mw2", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw1", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name)
})
t.Run("Should reverse 3 middlewares", func(t *testing.T) {
ctx := MiddlewareScenarioContext{}
middlewares := []plugins.ClientMiddleware{
ctx.NewMiddleware("mw1"),
ctx.NewMiddleware("mw2"),
ctx.NewMiddleware("mw3"),
}
reversed := reverseMiddlewares(middlewares)
require.Len(t, reversed, 3)
require.Equal(t, "mw3", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw2", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw1", reversed[2].CreateClientMiddleware(nil).(*TestMiddleware).Name)
})
t.Run("Should reverse 4 middlewares", func(t *testing.T) {
ctx := MiddlewareScenarioContext{}
middlewares := []plugins.ClientMiddleware{
ctx.NewMiddleware("mw1"),
ctx.NewMiddleware("mw2"),
ctx.NewMiddleware("mw3"),
ctx.NewMiddleware("mw4"),
}
reversed := reverseMiddlewares(middlewares)
require.Len(t, reversed, 4)
require.Equal(t, "mw4", reversed[0].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw3", reversed[1].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw2", reversed[2].CreateClientMiddleware(nil).(*TestMiddleware).Name)
require.Equal(t, "mw1", reversed[3].CreateClientMiddleware(nil).(*TestMiddleware).Name)
})
}
type TestClient struct {
plugins.Client
QueryDataFunc backend.QueryDataHandlerFunc
CallResourceFunc backend.CallResourceHandlerFunc
CheckHealthFunc backend.CheckHealthHandlerFunc
}
func (c *TestClient) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if c.QueryDataFunc != nil {
return c.QueryDataFunc(ctx, req)
}
return nil, nil
}
func (c *TestClient) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if c.CallResourceFunc != nil {
return c.CallResourceFunc(ctx, req, sender)
}
return nil
}
func (c *TestClient) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if c.CheckHealthFunc != nil {
return c.CheckHealthFunc(ctx, req)
}
return nil, nil
}
type MiddlewareScenarioContext struct {
QueryDataCallChain []string
CallResourceCallChain []string
CollectMetricsCallChain []string
CheckHealthCallChain []string
SubscribeStreamCallChain []string
PublishStreamCallChain []string
RunStreamCallChain []string
}
func (ctx *MiddlewareScenarioContext) NewMiddleware(name string) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &TestMiddleware{
next: next,
Name: name,
sCtx: ctx,
}
})
}
type TestMiddleware struct {
next plugins.Client
sCtx *MiddlewareScenarioContext
Name string
}
func (m *TestMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.QueryData(ctx, req)
m.sCtx.QueryDataCallChain = append(m.sCtx.QueryDataCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("before %s", m.Name))
err := m.next.CallResource(ctx, req, sender)
m.sCtx.CallResourceCallChain = append(m.sCtx.CallResourceCallChain, fmt.Sprintf("after %s", m.Name))
return err
}
func (m *TestMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.CollectMetrics(ctx, req)
m.sCtx.CollectMetricsCallChain = append(m.sCtx.CollectMetricsCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.CheckHealth(ctx, req)
m.sCtx.CheckHealthCallChain = append(m.sCtx.CheckHealthCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.SubscribeStream(ctx, req)
m.sCtx.SubscribeStreamCallChain = append(m.sCtx.SubscribeStreamCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("before %s", m.Name))
res, err := m.next.PublishStream(ctx, req)
m.sCtx.PublishStreamCallChain = append(m.sCtx.PublishStreamCallChain, fmt.Sprintf("after %s", m.Name))
return res, err
}
func (m *TestMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("before %s", m.Name))
err := m.next.RunStream(ctx, req, sender)
m.sCtx.RunStreamCallChain = append(m.sCtx.RunStreamCallChain, fmt.Sprintf("after %s", m.Name))
return err
}
type callResourceResponseSenderFunc func(res *backend.CallResourceResponse) error
func (fn callResourceResponseSenderFunc) Send(res *backend.CallResourceResponse) error {
return fn(res)
}
var _ plugins.Client = &TestClient{}

View File

@@ -1,13 +1,32 @@
package server
import (
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
"github.com/grafana/grafana/pkg/services/grpcserver"
"github.com/grafana/grafana/pkg/services/notifications"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/sqlstore"
)
func ProvideTestEnv(server *Server, store *sqlstore.SQLStore, ns *notifications.NotificationServiceMock, grpcServer grpcserver.Provider) (*TestEnv, error) {
return &TestEnv{server, store, ns, grpcServer}, nil
func ProvideTestEnv(
server *Server,
store *sqlstore.SQLStore,
ns *notifications.NotificationServiceMock,
grpcServer grpcserver.Provider,
pluginRegistry registry.Service,
httpClientProvider httpclient.Provider,
oAuthTokenService *oauthtokentest.Service,
) (*TestEnv, error) {
return &TestEnv{
server,
store,
ns,
grpcServer,
pluginRegistry,
httpClientProvider,
oAuthTokenService,
}, nil
}
type TestEnv struct {
@@ -15,4 +34,7 @@ type TestEnv struct {
SQLStore *sqlstore.SQLStore
NotificationService *notifications.NotificationServiceMock
GRPCServer grpcserver.Provider
PluginRegistry registry.Service
HTTPClientProvider httpclient.Provider
OAuthTokenService *oauthtokentest.Service
}

View File

@@ -28,18 +28,7 @@ import (
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
pluginsCfg "github.com/grafana/grafana/pkg/plugins/config"
"github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/plugins/manager/client"
pluginDashboards "github.com/grafana/grafana/pkg/plugins/manager/dashboards"
"github.com/grafana/grafana/pkg/plugins/manager/loader"
processManager "github.com/grafana/grafana/pkg/plugins/manager/process"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
managerStore "github.com/grafana/grafana/pkg/plugins/manager/store"
"github.com/grafana/grafana/pkg/plugins/plugincontext"
"github.com/grafana/grafana/pkg/plugins/repo"
"github.com/grafana/grafana/pkg/registry/corekind"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
@@ -93,12 +82,14 @@ import (
ngstore "github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/services/notifications"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/org/orgimpl"
"github.com/grafana/grafana/pkg/services/playlist/playlistimpl"
"github.com/grafana/grafana/pkg/services/plugindashboards"
plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service"
"github.com/grafana/grafana/pkg/services/pluginsettings"
pluginSettings "github.com/grafana/grafana/pkg/services/pluginsettings/service"
"github.com/grafana/grafana/pkg/services/pluginsintegration"
"github.com/grafana/grafana/pkg/services/preference/prefimpl"
"github.com/grafana/grafana/pkg/services/publicdashboards"
publicdashboardsApi "github.com/grafana/grafana/pkg/services/publicdashboards/api"
@@ -192,28 +183,9 @@ var wireBasicSet = wire.NewSet(
updatechecker.ProvidePluginsService,
uss.ProvideService,
wire.Bind(new(usagestats.Service), new(*uss.UsageStats)),
registry.ProvideService,
wire.Bind(new(registry.Service), new(*registry.InMemory)),
pluginsCfg.ProvideConfig,
repo.ProvideService,
wire.Bind(new(repo.Service), new(*repo.Manager)),
manager.ProvideInstaller,
wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)),
client.ProvideService,
wire.Bind(new(plugins.Client), new(*client.Service)),
managerStore.ProvideService,
wire.Bind(new(plugins.Store), new(*managerStore.Service)),
wire.Bind(new(plugins.RendererManager), new(*managerStore.Service)),
wire.Bind(new(plugins.SecretsPluginManager), new(*managerStore.Service)),
wire.Bind(new(plugins.StaticRouteResolver), new(*managerStore.Service)),
pluginsintegration.WireSet,
pluginDashboards.ProvideFileStoreManager,
wire.Bind(new(pluginDashboards.FileStore), new(*pluginDashboards.FileStoreManager)),
processManager.ProvideService,
wire.Bind(new(processManager.Service), new(*processManager.Manager)),
coreplugin.ProvideCoreRegistry,
loader.ProvideService,
wire.Bind(new(loader.Service), new(*loader.Loader)),
wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)),
cloudwatch.ProvideService,
cloudmonitoring.ProvideService,
azuremonitor.ProvideService,
@@ -251,7 +223,6 @@ var wireBasicSet = wire.NewSet(
export.ProvideService,
live.ProvideService,
pushhttp.ProvideService,
plugincontext.ProvideService,
contexthandler.ProvideService,
jwt.ProvideService,
wire.Bind(new(models.JWTService), new(*jwt.AuthService)),
@@ -271,8 +242,6 @@ var wireBasicSet = wire.NewSet(
social.ProvideService,
influxdb.ProvideService,
wire.Bind(new(social.Service), new(*social.SocialService)),
oauthtoken.ProvideService,
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)),
tempo.ProvideService,
loki.ProvideService,
graphite.ProvideService,
@@ -391,6 +360,8 @@ var wireSet = wire.NewSet(
wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)),
wire.Bind(new(db.DB), new(*sqlstore.SQLStore)),
prefimpl.ProvideService,
oauthtoken.ProvideService,
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)),
)
var wireTestSet = wire.NewSet(
@@ -407,6 +378,9 @@ var wireTestSet = wire.NewSet(
wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)),
wire.Bind(new(db.DB), new(*sqlstore.SQLStore)),
prefimpl.ProvideService,
oauthtoken.ProvideService,
oauthtokentest.ProvideService,
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtokentest.Service)),
)
func Initialize(cla setting.CommandLineArgs, opts Options, apiOpts api.ServerOptions) (*Server, error) {

View File

@@ -8,8 +8,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/provider"
"github.com/grafana/grafana/pkg/plugins/manager/signature"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/server/backgroundsvcs"
"github.com/grafana/grafana/pkg/server/usagestatssvcs"
@@ -29,6 +27,7 @@ import (
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
"github.com/grafana/grafana/pkg/services/pluginsintegration"
"github.com/grafana/grafana/pkg/services/provisioning"
"github.com/grafana/grafana/pkg/services/searchusers"
"github.com/grafana/grafana/pkg/services/searchusers/filters"
@@ -71,10 +70,6 @@ var wireExtsBasicSet = wire.NewSet(
wire.Bind(new(user.SearchUserFilter), new(*filters.OSSSearchUserFilter)),
searchusers.ProvideUsersService,
wire.Bind(new(searchusers.Service), new(*searchusers.OSSService)),
signature.ProvideOSSAuthorizer,
wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)),
provider.ProvideService,
wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)),
osskmsproviders.ProvideService,
wire.Bind(new(kmsproviders.Service), new(osskmsproviders.Service)),
ldap.ProvideGroupsService,
@@ -85,6 +80,7 @@ var wireExtsBasicSet = wire.NewSet(
wire.Bind(new(registry.UsageStatsProvidersRegistry), new(*usagestatssvcs.UsageStatsProvidersRegistry)),
ossaccesscontrol.ProvideDatasourcePermissionsService,
wire.Bind(new(accesscontrol.DatasourcePermissionsService), new(*ossaccesscontrol.DatasourcePermissionsService)),
pluginsintegration.WireExtensionSet,
)
var wireExtsSet = wire.NewSet(

View File

@@ -79,7 +79,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
// IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source.
func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
return IsOAuthPassThruEnabled(ds)
}
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
@@ -204,6 +204,11 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.Us
return token, nil
}
// IsOAuthPassThruEnabled returns true if Forward OAuth Identity (oauthPassThru) is enabled for the provided data source.
func IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
}
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
func tokensEq(t1, t2 *oauth2.Token) bool {
return t1.AccessToken == t2.AccessToken &&

View File

@@ -0,0 +1,41 @@
package oauthtokentest
import (
"context"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"golang.org/x/oauth2"
)
// Service an OAuth token service suitable for tests.
type Service struct {
Token *oauth2.Token
}
// ProvideService provides an OAuth token service suitable for tests.
func ProvideService() *Service {
return &Service{}
}
func (s *Service) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
return s.Token
}
func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return oauthtoken.IsOAuthPassThruEnabled(ds)
}
func (s *Service) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (s *Service) TryTokenRefresh(context.Context, *models.UserAuth) error {
return nil
}
func (s *Service) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
return nil
}

View File

@@ -0,0 +1,87 @@
package clientmiddleware
import (
"context"
"github.com/grafana/grafana-plugin-sdk-go/backend"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
)
// NewClearAuthHeadersMiddleware creates a new plugins.ClientMiddleware
// that will clear any outgoing HTTP headers that was part of the incoming
// HTTP request and used when authenticating to Grafana.
func NewClearAuthHeadersMiddleware() plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &ClearAuthHeadersMiddleware{
next: next,
}
})
}
type ClearAuthHeadersMiddleware struct {
next plugins.Client
}
func (m *ClearAuthHeadersMiddleware) clearHeaders(ctx context.Context, pCtx backend.PluginContext, req interface{}) context.Context {
reqCtx := contexthandler.FromContext(ctx)
// if no HTTP request context skip middleware
if req == nil || reqCtx == nil || reqCtx.Req == nil || reqCtx.SignedInUser == nil {
return ctx
}
list := contexthandler.AuthHTTPHeaderListFromContext(ctx)
if list != nil {
ctx = sdkhttpclient.WithContextualMiddleware(ctx, httpclientprovider.DeleteHeadersMiddleware(list.Items...))
}
return ctx
}
func (m *ClearAuthHeadersMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return m.next.QueryData(ctx, req)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
return m.next.QueryData(ctx, req)
}
func (m *ClearAuthHeadersMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return m.next.CallResource(ctx, req, sender)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
return m.next.CallResource(ctx, req, sender)
}
func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return m.next.CheckHealth(ctx, req)
}
ctx = m.clearHeaders(ctx, req.PluginContext, req)
return m.next.CheckHealth(ctx, req)
}
func (m *ClearAuthHeadersMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
return m.next.CollectMetrics(ctx, req)
}
func (m *ClearAuthHeadersMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
return m.next.SubscribeStream(ctx, req)
}
func (m *ClearAuthHeadersMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
return m.next.PublishStream(ctx, req)
}
func (m *ClearAuthHeadersMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
return m.next.RunStream(ctx, req, sender)
}

View File

@@ -0,0 +1,310 @@
package clientmiddleware
import (
"bytes"
"io"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/require"
)
func TestClearAuthHeadersMiddleware(t *testing.T) {
const otherHeader = "test"
t.Run("When no auth headers in reqContext", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
req.Header.Set(otherHeader, "test")
t.Run("And requests are for a datasource", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
}
t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
t.Run("And requests are for an app", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
)
pluginCtx := backend.PluginContext{
AppInstanceSettings: &backend.AppInstanceSettings{},
}
t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
})
t.Run("When auth headers in reqContext", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
t.Run("And requests are for a datasource", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
)
const customHeader = "X-Custom"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)
const otherHeader = "X-Other"
req.Header.Set(otherHeader, "test")
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{},
}
t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
})
t.Run("And requests are for an app", func(t *testing.T) {
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()),
)
const customHeader = "X-Custom"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)
const otherHeader = "X-Other"
req.Header.Set(otherHeader, "test")
pluginCtx := backend.PluginContext{
AppInstanceSettings: &backend.AppInstanceSettings{},
}
t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.DeleteHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Empty(t, reqClone.Header[customHeader])
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
})
})
}
var finalRoundTripper = httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Request: req,
Body: io.NopCloser(bytes.NewBufferString("")),
}, nil
})

View File

@@ -0,0 +1,126 @@
package clientmiddleware
import (
"context"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/util/proxyutil"
)
const cookieHeaderName = "Cookie"
// NewCookiesMiddleware creates a new plugins.ClientMiddleware that will
// forward incoming HTTP request Cookies to outgoing plugins.Client and
// HTTP requests if the datasource has enabled forwarding of cookies (keepCookies).
func NewCookiesMiddleware(skipCookiesNames []string) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &CookiesMiddleware{
next: next,
skipCookiesNames: skipCookiesNames,
}
})
}
type CookiesMiddleware struct {
next plugins.Client
skipCookiesNames []string
}
func (m *CookiesMiddleware) applyCookies(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
reqCtx := contexthandler.FromContext(ctx)
// if request not for a datasource or no HTTP request context skip middleware
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
return ctx, nil
}
settings := pCtx.DataSourceInstanceSettings
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
if err != nil {
return ctx, err
}
ds := &datasources.DataSource{
Id: settings.ID,
OrgId: pCtx.OrgID,
JsonData: jsonDataBytes,
Updated: settings.Updated,
}
proxyutil.ClearCookieHeader(reqCtx.Req, ds.AllowedCookies(), m.skipCookiesNames)
if cookieStr := reqCtx.Req.Header.Get(cookieHeaderName); cookieStr != "" {
switch t := req.(type) {
case *backend.QueryDataRequest:
t.Headers[cookieHeaderName] = cookieStr
case *backend.CheckHealthRequest:
t.Headers[cookieHeaderName] = cookieStr
case *backend.CallResourceRequest:
t.Headers[cookieHeaderName] = []string{cookieStr}
}
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(reqCtx.Req.Cookies(), ds.AllowedCookies(), m.skipCookiesNames))
return ctx, nil
}
func (m *CookiesMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return m.next.QueryData(ctx, req)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.QueryData(newCtx, req)
}
func (m *CookiesMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return m.next.CallResource(ctx, req, sender)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return err
}
return m.next.CallResource(newCtx, req, sender)
}
func (m *CookiesMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return m.next.CheckHealth(ctx, req)
}
newCtx, err := m.applyCookies(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.CheckHealth(newCtx, req)
}
func (m *CookiesMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
return m.next.CollectMetrics(ctx, req)
}
func (m *CookiesMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
return m.next.SubscribeStream(ctx, req)
}
func (m *CookiesMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
return m.next.PublishStream(ctx, req)
}
func (m *CookiesMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
return m.next.RunStream(ctx, req, sender)
}

View File

@@ -0,0 +1,223 @@
package clientmiddleware
import (
"encoding/json"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/require"
)
func TestCookiesMiddleware(t *testing.T) {
const otherHeader = "test"
t.Run("When keepCookies not configured for a datasource", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.Header.Set(otherHeader, "test")
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewCookiesMiddleware([]string{"grafana_session"})),
)
jsonDataMap := map[string]interface{}{}
jsonDataBytes, err := json.Marshal(&jsonDataMap)
require.NoError(t, err)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{
JSONData: jsonDataBytes,
},
}
t.Run("Should not forward cookies when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should not forward cookies when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
t.Run("Should not forward cookies when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 1)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
})
})
t.Run("When keepCookies configured for a datasource", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
req.Header.Set(otherHeader, "test")
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewCookiesMiddleware([]string{"grafana_session"})),
)
jsonDataMap := map[string]interface{}{
"keepCookies": []string{"cookie2", "grafana_session"},
}
jsonDataBytes, err := json.Marshal(&jsonDataMap)
require.NoError(t, err)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{
JSONData: jsonDataBytes,
},
}
t.Run("Should forward cookies when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 2)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
require.EqualValues(t, "cookie2=", cdt.QueryDataReq.Headers[cookieHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
t.Run("Should forward cookies when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 2)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
require.Len(t, cdt.CallResourceReq.Headers[cookieHeaderName], 1)
require.EqualValues(t, "cookie2=", cdt.CallResourceReq.Headers[cookieHeaderName][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
t.Run("Should forward cookies when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 2)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
require.EqualValues(t, "cookie2=", cdt.CheckHealthReq.Headers[cookieHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 2)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "cookie2=", reqClone.Header.Get(cookieHeaderName))
})
})
}

View File

@@ -0,0 +1,155 @@
package clientmiddleware
import (
"context"
"fmt"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
)
// NewOAuthTokenMiddleware creates a new plugins.ClientMiddleware that will
// set OAuth token headers on outgoing plugins.Client and HTTP requests if
// the datasource has enabled Forward OAuth Identity (oauthPassThru).
func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plugins.ClientMiddleware {
return plugins.ClientMiddlewareFunc(func(next plugins.Client) plugins.Client {
return &OAuthTokenMiddleware{
next: next,
oAuthTokenService: oAuthTokenService,
}
})
}
const (
tokenHeaderName = "Authorization"
idTokenHeaderName = "X-ID-Token"
)
type OAuthTokenMiddleware struct {
oAuthTokenService oauthtoken.OAuthTokenService
next plugins.Client
}
func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.PluginContext, req interface{}) (context.Context, error) {
reqCtx := contexthandler.FromContext(ctx)
// if request not for a datasource or no HTTP request context skip middleware
if req == nil || pCtx.DataSourceInstanceSettings == nil || reqCtx == nil || reqCtx.Req == nil {
return ctx, nil
}
settings := pCtx.DataSourceInstanceSettings
jsonDataBytes, err := simplejson.NewJson(settings.JSONData)
if err != nil {
return ctx, err
}
ds := &datasources.DataSource{
Id: settings.ID,
OrgId: pCtx.OrgID,
JsonData: jsonDataBytes,
Updated: settings.Updated,
}
if m.oAuthTokenService.IsOAuthPassThruEnabled(ds) {
if token := m.oAuthTokenService.GetCurrentOAuthToken(ctx, reqCtx.SignedInUser); token != nil {
authorizationHeader := fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
idTokenHeader := ""
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
idTokenHeader = idToken
}
switch t := req.(type) {
case *backend.QueryDataRequest:
t.Headers[tokenHeaderName] = authorizationHeader
if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = idTokenHeader
}
case *backend.CheckHealthRequest:
t.Headers[tokenHeaderName] = authorizationHeader
if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = idTokenHeader
}
case *backend.CallResourceRequest:
t.Headers[tokenHeaderName] = []string{authorizationHeader}
if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = []string{idTokenHeader}
}
}
httpHeaders := http.Header{}
httpHeaders.Set(tokenHeaderName, authorizationHeader)
if idTokenHeader != "" {
httpHeaders.Set(idTokenHeaderName, idTokenHeader)
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.SetHeadersMiddleware(httpHeaders))
}
}
return ctx, nil
}
func (m *OAuthTokenMiddleware) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if req == nil {
return m.next.QueryData(ctx, req)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.QueryData(newCtx, req)
}
func (m *OAuthTokenMiddleware) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if req == nil {
return m.next.CallResource(ctx, req, sender)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return err
}
return m.next.CallResource(newCtx, req, sender)
}
func (m *OAuthTokenMiddleware) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if req == nil {
return m.next.CheckHealth(ctx, req)
}
newCtx, err := m.applyToken(ctx, req.PluginContext, req)
if err != nil {
return nil, err
}
return m.next.CheckHealth(newCtx, req)
}
func (m *OAuthTokenMiddleware) CollectMetrics(ctx context.Context, req *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
return m.next.CollectMetrics(ctx, req)
}
func (m *OAuthTokenMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
return m.next.SubscribeStream(ctx, req)
}
func (m *OAuthTokenMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
return m.next.PublishStream(ctx, req)
}
func (m *OAuthTokenMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
return m.next.RunStream(ctx, req, sender)
}

View File

@@ -0,0 +1,197 @@
package clientmiddleware
import (
"encoding/json"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/plugins/manager/client/clienttest"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestOAuthTokenMiddleware(t *testing.T) {
const otherHeader = "test"
t.Run("When oauthPassThru not configured for a datasource", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
req.Header.Set(otherHeader, "test")
oAuthTokenService := &oauthtokentest.Service{}
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewOAuthTokenMiddleware(oAuthTokenService)),
)
jsonDataMap := map[string]interface{}{}
jsonDataBytes, err := json.Marshal(&jsonDataMap)
require.NoError(t, err)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{
JSONData: jsonDataBytes,
},
}
t.Run("Should not forward OAuth Identity when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 1)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not forward OAuth Identity when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 1)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 0)
})
t.Run("Should not forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 1)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 0)
})
})
t.Run("When oauthPassThru configured for a datasource", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/some/thing", nil)
require.NoError(t, err)
req.Header.Set(otherHeader, "test")
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
oAuthTokenService := &oauthtokentest.Service{
Token: token,
}
cdt := clienttest.NewClientDecoratorTest(t,
clienttest.WithReqContext(req, &user.SignedInUser{}),
clienttest.WithMiddlewares(NewOAuthTokenMiddleware(oAuthTokenService)),
)
jsonDataMap := map[string]interface{}{
"oauthPassThru": true,
}
jsonDataBytes, err := json.Marshal(&jsonDataMap)
require.NoError(t, err)
pluginCtx := backend.PluginContext{
DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{
JSONData: jsonDataBytes,
},
}
t.Run("Should forward OAuth Identity when calling QueryData", func(t *testing.T) {
_, err = cdt.Decorator.QueryData(req.Context(), &backend.QueryDataRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 3)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName])
require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.QueryDataCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) {
err = cdt.Decorator.CallResource(req.Context(), &backend.CallResourceRequest{
PluginContext: pluginCtx,
Headers: map[string][]string{otherHeader: {"test"}},
}, nopCallResourceSender)
require.NoError(t, err)
require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 3)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
require.Len(t, cdt.CallResourceReq.Headers[tokenHeaderName], 1)
require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0])
require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1)
require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CallResourceCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
_, err = cdt.Decorator.CheckHealth(req.Context(), &backend.CheckHealthRequest{
PluginContext: pluginCtx,
Headers: map[string]string{otherHeader: "test"},
})
require.NoError(t, err)
require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 3)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName])
require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName])
middlewares := httpclient.ContextualMiddlewareFromContext(cdt.CheckHealthCtx)
require.Len(t, middlewares, 1)
require.Equal(t, httpclientprovider.SetHeadersMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
reqClone := req.Clone(req.Context())
res, err := middlewares[0].CreateMiddleware(httpclient.Options{}, finalRoundTripper).RoundTrip(reqClone)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Len(t, reqClone.Header, 3)
require.Equal(t, "test", reqClone.Header.Get(otherHeader))
require.Equal(t, "Bearer access-token", reqClone.Header.Get(tokenHeaderName))
require.Equal(t, "id-token", reqClone.Header.Get(idTokenHeaderName))
})
})
}

View File

@@ -0,0 +1,13 @@
package clientmiddleware
import "github.com/grafana/grafana-plugin-sdk-go/backend"
type callResourceResponseSenderFunc func(res *backend.CallResourceResponse) error
func (fn callResourceResponseSenderFunc) Send(res *backend.CallResourceResponse) error {
return fn(res)
}
var nopCallResourceSender = callResourceResponseSenderFunc(func(res *backend.CallResourceResponse) error {
return nil
})

View File

@@ -0,0 +1,4 @@
// package pluginsintegration instantiate the plugins
// package (pkg/plugins) and all of its services/dependencies that
// Grafana needs to provide plugin support.
package pluginsintegration

View File

@@ -0,0 +1,81 @@
package pluginsintegration
import (
"github.com/google/wire"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
"github.com/grafana/grafana/pkg/plugins/backendplugin/provider"
"github.com/grafana/grafana/pkg/plugins/config"
"github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/plugins/manager/client"
"github.com/grafana/grafana/pkg/plugins/manager/loader"
"github.com/grafana/grafana/pkg/plugins/manager/process"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
"github.com/grafana/grafana/pkg/plugins/manager/signature"
"github.com/grafana/grafana/pkg/plugins/manager/store"
"github.com/grafana/grafana/pkg/plugins/plugincontext"
"github.com/grafana/grafana/pkg/plugins/repo"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/pluginsintegration/clientmiddleware"
"github.com/grafana/grafana/pkg/setting"
)
// WireSet provides a wire.ProviderSet of plugin providers.
var WireSet = wire.NewSet(
config.ProvideConfig,
store.ProvideService,
wire.Bind(new(plugins.Store), new(*store.Service)),
wire.Bind(new(plugins.RendererManager), new(*store.Service)),
wire.Bind(new(plugins.SecretsPluginManager), new(*store.Service)),
wire.Bind(new(plugins.StaticRouteResolver), new(*store.Service)),
ProvideClientDecorator,
wire.Bind(new(plugins.Client), new(*client.Decorator)),
process.ProvideService,
wire.Bind(new(process.Service), new(*process.Manager)),
coreplugin.ProvideCoreRegistry,
loader.ProvideService,
wire.Bind(new(loader.Service), new(*loader.Loader)),
wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)),
manager.ProvideInstaller,
wire.Bind(new(plugins.Installer), new(*manager.PluginInstaller)),
registry.ProvideService,
wire.Bind(new(registry.Service), new(*registry.InMemory)),
repo.ProvideService,
wire.Bind(new(repo.Service), new(*repo.Manager)),
plugincontext.ProvideService,
)
// WireExtensionSet provides a wire.ProviderSet of plugin providers that can be
// extended.
var WireExtensionSet = wire.NewSet(
provider.ProvideService,
wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)),
signature.ProvideOSSAuthorizer,
wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)),
)
func ProvideClientDecorator(cfg *setting.Cfg, pCfg *config.Cfg,
pluginRegistry registry.Service,
oAuthTokenService oauthtoken.OAuthTokenService) (*client.Decorator, error) {
return NewClientDecorator(cfg, pCfg, pluginRegistry, oAuthTokenService)
}
func NewClientDecorator(cfg *setting.Cfg, pCfg *config.Cfg,
pluginRegistry registry.Service,
oAuthTokenService oauthtoken.OAuthTokenService) (*client.Decorator, error) {
c := client.ProvideService(pluginRegistry, pCfg)
middlewares := CreateMiddlewares(cfg, oAuthTokenService)
return client.NewDecorator(c, middlewares...)
}
func CreateMiddlewares(cfg *setting.Cfg, oAuthTokenService oauthtoken.OAuthTokenService) []plugins.ClientMiddleware {
skipCookiesNames := []string{cfg.LoginCookieName}
middlewares := []plugins.ClientMiddleware{
clientmiddleware.NewClearAuthHeadersMiddleware(),
clientmiddleware.NewOAuthTokenMiddleware(oAuthTokenService),
clientmiddleware.NewCookiesMiddleware(skipCookiesNames),
}
return middlewares
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/db"
@@ -137,7 +136,6 @@ func buildQueryDataService(t *testing.T, cs datasources.CacheService, fpc *fakeP
&fakePluginRequestValidator{},
&fakeDatasources.FakeDataSourceService{},
fpc,
&fakeOAuthTokenService{},
)
}
@@ -150,31 +148,6 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request)
return rv.err
}
type fakeOAuthTokenService struct {
passThruEnabled bool
token *oauth2.Token
}
func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
return ts.token
}
func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return nil
}
// copied from pkg/api/plugins_test.go
type fakePluginClient struct {
plugins.Client

View File

@@ -0,0 +1,90 @@
package query
import (
"context"
"strings"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
)
type parsedQuery struct {
datasource *datasources.DataSource
query backend.DataQuery
rawQuery *simplejson.Json
}
type parsedRequest struct {
hasExpression bool
parsedQueries map[string][]parsedQuery
dsTypes map[string]bool
}
func (pr parsedRequest) getFlattenedQueries() []parsedQuery {
queries := make([]parsedQuery, 0)
for _, pq := range pr.parsedQueries {
queries = append(queries, pq...)
}
return queries
}
func (pr parsedRequest) validateRequest(ctx context.Context) error {
reqCtx := contexthandler.FromContext(ctx)
if reqCtx == nil || reqCtx.Req == nil {
return nil
}
httpReq := reqCtx.Req
if pr.hasExpression {
hasExpr := httpReq.URL.Query().Get("expression")
if hasExpr == "" || hasExpr == "true" {
return nil
}
return ErrQueryParamMismatch
}
vals := splitHeaders(httpReq.Header.Values(HeaderDatasourceUID))
count := len(vals)
if count > 0 { // header exists
if count != len(pr.parsedQueries) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if pr.parsedQueries[t] == nil {
return ErrQueryParamMismatch
}
}
}
vals = splitHeaders(httpReq.Header.Values(HeaderPluginID))
count = len(vals)
if count > 0 { // header exists
if count != len(pr.dsTypes) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if !pr.dsTypes[t] {
return ErrQueryParamMismatch
}
}
}
return nil
}
func splitHeaders(headers []string) []string {
out := []string{}
for _, v := range headers {
if strings.Contains(v, ",") {
for _, sub := range strings.Split(v, ",") {
out = append(out, strings.TrimSpace(sub))
}
} else {
out = append(out, v)
}
}
return out
}

View File

@@ -5,26 +5,21 @@ import (
"fmt"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/adapters"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/grafanads"
"github.com/grafana/grafana/pkg/tsdb/legacydata"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/grafana/grafana/pkg/util/proxyutil"
"golang.org/x/sync/errgroup"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
const (
@@ -41,7 +36,6 @@ func ProvideService(
pluginRequestValidator models.PluginRequestValidator,
dataSourceService datasources.DataSourceService,
pluginClient plugins.Client,
oAuthTokenService oauthtoken.OAuthTokenService,
) *Service {
g := &Service{
cfg: cfg,
@@ -50,7 +44,6 @@ func ProvideService(
pluginRequestValidator: pluginRequestValidator,
dataSourceService: dataSourceService,
pluginClient: pluginClient,
oAuthTokenService: oAuthTokenService,
log: log.New("query_data"),
}
g.log.Info("Query Service initialization")
@@ -64,7 +57,6 @@ type Service struct {
pluginRequestValidator models.PluginRequestValidator
dataSourceService datasources.DataSourceService
pluginClient plugins.Client
oAuthTokenService oauthtoken.OAuthTokenService
log log.Logger
}
@@ -175,9 +167,6 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser
exprReq.OrgId = user.OrgID
}
disallowedCookies := []string{s.cfg.LoginCookieName}
queryEnrichers := parsedReq.createDataSourceQueryEnrichers(ctx, user, s.oAuthTokenService, disallowedCookies)
for _, pq := range parsedReq.getFlattenedQueries() {
if pq.datasource == nil {
return nil, ErrMissingDataSourceInfo.Build(errutil.TemplateData{
@@ -198,7 +187,6 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser
From: pq.query.TimeRange.From,
To: pq.query.TimeRange.To,
},
QueryEnricher: queryEnrichers[pq.datasource.Uid],
})
}
@@ -240,39 +228,10 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si
Queries: []backend.DataQuery{},
}
disallowedCookies := []string{s.cfg.LoginCookieName}
middlewares := []httpclient.Middleware{}
if parsedReq.httpRequest != nil {
middlewares = append(middlewares,
httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), disallowedCookies),
)
}
if s.oAuthTokenService.IsOAuthPassThruEnabled(ds) {
if token := s.oAuthTokenService.GetCurrentOAuthToken(ctx, user); token != nil {
req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Headers["X-ID-Token"] = idToken
}
middlewares = append(middlewares, httpclientprovider.ForwardedOAuthIdentityMiddleware(token))
}
}
if parsedReq.httpRequest != nil {
proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), disallowedCookies)
if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" {
req.Headers["Cookie"] = cookieStr
}
}
for _, q := range queries {
req.Queries = append(req.Queries, q.query)
}
ctx = httpclient.WithContextualMiddleware(ctx, middlewares...)
return s.pluginClient.QueryData(ctx, req)
}
@@ -335,11 +294,7 @@ func (s *Service) parseMetricRequest(ctx context.Context, user *user.SignedInUse
})
}
if reqDTO.HTTPRequest != nil {
req.httpRequest = reqDTO.HTTPRequest
}
_ = req.validateRequest()
_ = req.validateRequest(ctx)
return req, nil // TODO req.validateRequest()
}

View File

@@ -1,157 +0,0 @@
package query
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/proxyutil"
"golang.org/x/oauth2"
)
type parsedQuery struct {
datasource *datasources.DataSource
query backend.DataQuery
rawQuery *simplejson.Json
}
type parsedRequest struct {
hasExpression bool
parsedQueries map[string][]parsedQuery
dsTypes map[string]bool
httpRequest *http.Request
}
func (pr parsedRequest) getFlattenedQueries() []parsedQuery {
queries := make([]parsedQuery, 0)
for _, pq := range pr.parsedQueries {
queries = append(queries, pq...)
}
return queries
}
func (pr parsedRequest) validateRequest() error {
if pr.httpRequest == nil {
return nil
}
if pr.hasExpression {
hasExpr := pr.httpRequest.URL.Query().Get("expression")
if hasExpr == "" || hasExpr == "true" {
return nil
}
return ErrQueryParamMismatch
}
vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID))
count := len(vals)
if count > 0 { // header exists
if count != len(pr.parsedQueries) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if pr.parsedQueries[t] == nil {
return ErrQueryParamMismatch
}
}
}
vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID))
count = len(vals)
if count > 0 { // header exists
if count != len(pr.dsTypes) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if !pr.dsTypes[t] {
return ErrQueryParamMismatch
}
}
}
return nil
}
func (pr parsedRequest) createDataSourceQueryEnrichers(ctx context.Context, signedInUser *user.SignedInUser, oAuthTokenService oauthtoken.OAuthTokenService, disallowedCookies []string) map[string]expr.QueryDataRequestEnricher {
datasourcesHeaderProvider := map[string]expr.QueryDataRequestEnricher{}
if pr.httpRequest == nil {
return datasourcesHeaderProvider
}
for uid, queries := range pr.parsedQueries {
if expr.IsDataSource(uid) {
continue
}
if len(queries) == 0 || queries[0].datasource == nil {
continue
}
if _, exists := datasourcesHeaderProvider[uid]; exists {
continue
}
ds := queries[0].datasource
allowedCookies := ds.AllowedCookies()
clonedReq := pr.httpRequest.Clone(pr.httpRequest.Context())
var token *oauth2.Token
if oAuthTokenService.IsOAuthPassThruEnabled(ds) {
token = oAuthTokenService.GetCurrentOAuthToken(ctx, signedInUser)
}
datasourcesHeaderProvider[uid] = func(ctx context.Context, req *backend.QueryDataRequest) context.Context {
if len(req.Headers) == 0 {
req.Headers = map[string]string{}
}
if len(allowedCookies) > 0 {
proxyutil.ClearCookieHeader(clonedReq, allowedCookies, disallowedCookies)
if cookieStr := clonedReq.Header.Get("Cookie"); cookieStr != "" {
req.Headers["Cookie"] = cookieStr
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(clonedReq.Cookies(), allowedCookies, disallowedCookies))
}
if token != nil {
req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Headers["X-ID-Token"] = idToken
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedOAuthIdentityMiddleware(token))
}
return ctx
}
}
return datasourcesHeaderProvider
}
func splitHeaders(headers []string) []string {
out := []string{}
for _, v := range headers {
if strings.Contains(v, ",") {
for _, sub := range strings.Split(v, ",") {
out = append(out, strings.TrimSpace(sub))
}
} else {
out = append(out, v)
}
}
return out
}

View File

@@ -5,24 +5,22 @@ import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/models/roletype"
"github.com/grafana/grafana/pkg/plugins"
acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock"
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
"github.com/grafana/grafana/pkg/services/datasources"
fakeDatasources "github.com/grafana/grafana/pkg/services/datasources/fakes"
dsSvc "github.com/grafana/grafana/pkg/services/datasources/service"
@@ -33,35 +31,12 @@ import (
secretsmng "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
func TestParseMetricRequest(t *testing.T) {
t.Run("Test a simple single datasource query", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{
"refId": "A",
"datasource": {
@@ -82,61 +57,10 @@ func TestParseMetricRequest(t *testing.T) {
assert.Len(t, parsedReq.parsedQueries, 1)
assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz")
assert.Len(t, parsedReq.getFlattenedQueries(), 2)
t.Run("createDataSourceQueryEnrichers should return 0 enrichers when no HTTP request", func(t *testing.T) {
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{})
require.Empty(t, enrichers)
})
t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 1)
require.NotNil(t, enrichers["gIEkMvIVz"])
req := &backend.QueryDataRequest{}
ctx := enrichers["gIEkMvIVz"](context.Background(), req)
require.Len(t, req.Headers, 3)
require.Equal(t, "Bearer access-token", req.Headers["Authorization"])
require.Equal(t, "id-token", req.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"])
middlewares := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewares, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName())
})
})
t.Run("Test a single datasource query with expressions", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{
"refId": "A",
"datasource": {
@@ -156,75 +80,27 @@ func TestParseMetricRequest(t *testing.T) {
parsedReq, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr)
require.NoError(t, err)
require.NotNil(t, parsedReq)
assert.True(t, parsedReq.hasExpression)
assert.Len(t, parsedReq.parsedQueries, 2)
assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz")
assert.Len(t, parsedReq.getFlattenedQueries(), 2)
require.True(t, parsedReq.hasExpression)
require.Len(t, parsedReq.parsedQueries, 2)
require.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz")
require.Len(t, parsedReq.getFlattenedQueries(), 2)
// Make sure we end up with something valid
_, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq)
assert.NoError(t, err)
require.NoError(t, err)
t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 1)
require.NotNil(t, enrichers["gIEkMvIVz"])
req := &backend.QueryDataRequest{}
ctx := enrichers["gIEkMvIVz"](context.Background(), req)
require.Len(t, req.Headers, 3)
require.Equal(t, "Bearer access-token", req.Headers["Authorization"])
require.Equal(t, "id-token", req.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"])
middlewares := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewares, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName())
t.Run("Should forward user and org ID to QueryData from expression request", func(t *testing.T) {
require.NotNil(t, tc.pluginContext.req)
require.NotNil(t, tc.pluginContext.req.PluginContext.User)
require.Equal(t, tc.signedInUser.Login, tc.pluginContext.req.PluginContext.User.Login)
require.Equal(t, tc.signedInUser.Name, tc.pluginContext.req.PluginContext.User.Name)
require.Equal(t, tc.signedInUser.Email, tc.pluginContext.req.PluginContext.User.Email)
require.Equal(t, string(tc.signedInUser.OrgRole), tc.pluginContext.req.PluginContext.User.Role)
require.Equal(t, tc.signedInUser.OrgID, tc.pluginContext.req.PluginContext.OrgID)
})
})
t.Run("Test a simple mixed datasource query", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
json2, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie2" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
if datasourceUID == "sEx6ZvSVk" {
return &datasources.DataSource{
Uid: "sEx6ZvSVk",
JsonData: json2,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{
"refId": "A",
"datasource": {
@@ -254,43 +130,6 @@ func TestParseMetricRequest(t *testing.T) {
assert.Contains(t, parsedReq.parsedQueries, "sEx6ZvSVk")
assert.Len(t, parsedReq.parsedQueries["sEx6ZvSVk"], 2)
assert.Len(t, parsedReq.getFlattenedQueries(), 3)
t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 2)
enricherOne := enrichers["gIEkMvIVz"]
require.NotNil(t, enricherOne)
reqOne := &backend.QueryDataRequest{}
ctx := enricherOne(context.Background(), reqOne)
require.Len(t, reqOne.Headers, 3)
require.Equal(t, "Bearer access-token", reqOne.Headers["Authorization"])
require.Equal(t, "id-token", reqOne.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", reqOne.Headers["Cookie"])
middlewaresOne := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewaresOne, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresOne[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresOne[1].(httpclient.MiddlewareName).MiddlewareName())
enricherTwo := enrichers["sEx6ZvSVk"]
require.NotNil(t, enricherTwo)
reqTwo := &backend.QueryDataRequest{}
ctx = enricherTwo(context.Background(), reqTwo)
require.Len(t, reqTwo.Headers, 3)
require.Equal(t, "Bearer access-token", reqTwo.Headers["Authorization"])
require.Equal(t, "id-token", reqTwo.Headers["X-ID-Token"])
require.Equal(t, "cookie2=", reqTwo.Headers["Cookie"])
middlewaresTwo := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewaresTwo, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresTwo[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresTwo[1].(httpclient.MiddlewareName).MiddlewareName())
})
})
t.Run("Test a mixed datasource query with expressions", func(t *testing.T) {
@@ -352,14 +191,6 @@ func TestParseMetricRequest(t *testing.T) {
// Make sure we end up with something valid
_, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq)
assert.NoError(t, err)
t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) {
parsedReq.httpRequest = &http.Request{}
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{})
require.Len(t, enrichers, 2)
require.NotNil(t, enrichers["gIEkMvIVz"])
require.NotNil(t, enrichers["sEx6ZvSVk"])
})
})
t.Run("Header validation", func(t *testing.T) {
@@ -377,23 +208,30 @@ func TestParseMetricRequest(t *testing.T) {
"type": "testdata"
}
}`)
httpreq, _ := http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{}))
httpreq, err := http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{}))
require.NoError(t, err)
reqCtx := &models.ReqContext{
Context: &web.Context{},
}
ctx := ctxkey.Set(context.Background(), reqCtx)
*httpreq = *httpreq.WithContext(ctx)
reqCtx.Req = httpreq
httpreq.Header.Add("X-Datasource-Uid", "gIEkMvIVz")
mr.HTTPRequest = httpreq
_, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr)
_, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr)
require.NoError(t, err)
// With the second value it is OK
httpreq.Header.Add("X-Datasource-Uid", "sEx6ZvSVk")
mr.HTTPRequest = httpreq
_, err = tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr)
_, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr)
require.NoError(t, err)
// Single header with comma syntax
httpreq, _ = http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader([]byte{}))
httpreq.Header.Set("X-Datasource-Uid", "gIEkMvIVz, sEx6ZvSVk")
mr.HTTPRequest = httpreq
_, err = tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr)
_, err = tc.queryService.parseMetricRequest(httpreq.Context(), tc.signedInUser, true, mr)
require.NoError(t, err)
})
}
@@ -428,7 +266,6 @@ func TestQueryDataMultipleSources(t *testing.T) {
Queries: queries,
Debug: false,
PublicDashboardAccessToken: "abc123",
HTTPRequest: nil,
}
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO)
@@ -479,19 +316,27 @@ func TestQueryDataMultipleSources(t *testing.T) {
Queries: queries,
Debug: false,
PublicDashboardAccessToken: "abc123",
HTTPRequest: nil,
}
// without query parameter
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO)
require.NoError(t, err)
httpreq, _ := http.NewRequest(http.MethodPost, "http://localhost/ds/query?expression=true", bytes.NewReader([]byte{}))
httpreq, err := http.NewRequest(http.MethodPost, "http://localhost/ds/query?expression=true", bytes.NewReader([]byte{}))
require.NoError(t, err)
reqCtx := &models.ReqContext{
Context: &web.Context{},
}
ctx := ctxkey.Set(context.Background(), reqCtx)
*httpreq = *httpreq.WithContext(ctx)
reqCtx.Req = httpreq
httpreq.Header.Add("X-Datasource-Uid", "gIEkMvIVz")
reqDTO.HTTPRequest = httpreq
// with query parameter
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO)
_, err = tc.queryService.QueryData(httpreq.Context(), tc.signedInUser, true, reqDTO)
require.NoError(t, err)
})
@@ -526,7 +371,6 @@ func TestQueryDataMultipleSources(t *testing.T) {
Queries: queries,
Debug: false,
PublicDashboardAccessToken: "abc123",
HTTPRequest: nil,
}
res, err := tc.queryService.QueryData(context.Background(), tc.signedInUser, true, reqDTO)
@@ -538,99 +382,10 @@ func TestQueryDataMultipleSources(t *testing.T) {
})
}
func TestQueryData(t *testing.T) {
t.Run("it auth custom headers to the request", func(t *testing.T) {
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc := setup(t)
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
metricReq := metricRequest()
httpReq, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
metricReq.HTTPRequest = httpReq
_, err = tc.queryService.QueryData(context.Background(), nil, true, metricReq)
require.Nil(t, err)
expected := map[string]string{
"Authorization": "Bearer access-token",
"X-ID-Token": "id-token",
}
require.Equal(t, expected, tc.pluginContext.req.Headers)
})
t.Run("it doesn't add cookie header to the request when keepCookies configured and no cookies provided", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "foo", "bar" ]}`))
require.NoError(t, err)
tc.dataSourceCache.ds.JsonData = json
metricReq := metricRequest()
httpReq, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
metricReq.HTTPRequest = httpReq
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq)
require.NoError(t, err)
require.Empty(t, tc.pluginContext.req.Headers)
})
t.Run("it adds cookie header to the request when keepCookies configured and cookie provided", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "foo", "bar" ]}`))
require.NoError(t, err)
tc.dataSourceCache.ds.JsonData = json
metricReq := metricRequest()
httpReq, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
httpReq.AddCookie(&http.Cookie{Name: "a"})
httpReq.AddCookie(&http.Cookie{Name: "bar", Value: "rab"})
httpReq.AddCookie(&http.Cookie{Name: "b"})
httpReq.AddCookie(&http.Cookie{Name: "foo", Value: "oof"})
httpReq.AddCookie(&http.Cookie{Name: "c"})
metricReq.HTTPRequest = httpReq
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq)
require.NoError(t, err)
require.Equal(t, map[string]string{"Cookie": "bar=rab; foo=oof"}, tc.pluginContext.req.Headers)
})
t.Run("it doesn't adds cookie header to the request when keepCookies configured with login cookie name", func(t *testing.T) {
tc := setup(t)
tc.queryService.cfg.LoginCookieName = "grafana_session"
json, err := simplejson.NewJson([]byte(`{"keepCookies": [ "grafana_session", "bar" ]}`))
require.NoError(t, err)
tc.dataSourceCache.ds.JsonData = json
metricReq := metricRequest()
httpReq, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
httpReq.AddCookie(&http.Cookie{Name: "a"})
httpReq.AddCookie(&http.Cookie{Name: "bar", Value: "rab"})
httpReq.AddCookie(&http.Cookie{Name: "b"})
httpReq.AddCookie(&http.Cookie{Name: "foo", Value: "oof"})
httpReq.AddCookie(&http.Cookie{Name: "c"})
httpReq.AddCookie(&http.Cookie{Name: tc.queryService.cfg.LoginCookieName, Value: "val"})
metricReq.HTTPRequest = httpReq
_, err = tc.queryService.QueryData(context.Background(), tc.signedInUser, true, metricReq)
require.NoError(t, err)
require.Equal(t, map[string]string{"Cookie": "bar=rab"}, tc.pluginContext.req.Headers)
})
}
func setup(t *testing.T) *testContext {
t.Helper()
pc := &fakePluginClient{}
dc := &fakeDataSourceCache{ds: &datasources.DataSource{}}
tc := &fakeOAuthTokenService{}
rv := &fakePluginRequestValidator{}
sqlStore := db.InitTestDB(t)
@@ -645,15 +400,14 @@ func setup(t *testing.T) *testContext {
SimulatePluginFailure: false,
}
exprService := expr.ProvideService(&setting.Cfg{ExpressionsEnabled: true}, pc, fakeDatasourceService)
queryService := ProvideService(setting.NewCfg(), dc, exprService, rv, ds, pc, tc) // provider belonging to this package
queryService := ProvideService(setting.NewCfg(), dc, exprService, rv, ds, pc) // provider belonging to this package
return &testContext{
pluginContext: pc,
secretStore: ss,
dataSourceCache: dc,
oauthTokenService: tc,
pluginRequestValidator: rv,
queryService: queryService,
signedInUser: &user.SignedInUser{OrgID: 1},
signedInUser: &user.SignedInUser{OrgID: 1, Login: "login", Name: "name", Email: "email", OrgRole: roletype.RoleAdmin},
}
}
@@ -661,22 +415,11 @@ type testContext struct {
pluginContext *fakePluginClient
secretStore secretskvs.SecretsKVStore
dataSourceCache *fakeDataSourceCache
oauthTokenService *fakeOAuthTokenService
pluginRequestValidator *fakePluginRequestValidator
queryService *Service // implementation belonging to this package
signedInUser *user.SignedInUser
}
func metricRequest() dtos.MetricRequest {
q, _ := simplejson.NewJson([]byte(`{"datasourceId":1}`))
return dtos.MetricRequest{
From: "",
To: "",
Queries: []*simplejson.Json{q},
Debug: false,
}
}
func metricRequestWithQueries(t *testing.T, rawQueries ...string) dtos.MetricRequest {
t.Helper()
queries := make([]*simplejson.Json, 0)
@@ -701,34 +444,8 @@ func (rv *fakePluginRequestValidator) Validate(dsURL string, req *http.Request)
return rv.err
}
type fakeOAuthTokenService struct {
passThruEnabled bool
token *oauth2.Token
}
func (ts *fakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
return ts.token
}
func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
return nil
}
type fakeDataSourceCache struct {
ds *datasources.DataSource
dsByUid func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error)
ds *datasources.DataSource
}
func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
@@ -736,10 +453,6 @@ func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID in
}
func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if c.dsByUid != nil {
return c.dsByUid(ctx, datasourceUID, user, skipCache)
}
return &datasources.DataSource{
Uid: datasourceUID,
}, nil

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationElasticsearch(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -107,14 +106,3 @@ func TestIntegrationElasticsearch(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationGraphite(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -104,14 +103,3 @@ func TestIntegrationGraphite(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationInflux(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -103,14 +102,3 @@ func TestIntegrationInflux(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationLoki(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -102,14 +101,3 @@ func TestIntegrationLoki(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationOpenTSDB(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -103,14 +102,3 @@ func TestIntegrationOpenTSDB(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -0,0 +1,619 @@
package backendplugin
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/server"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestIntegrationBackendPlugins(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
regularQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
return metricRequestWithQueries(t, fmt.Sprintf(`{
"datasource": {
"uid": "%s"
}
}`, tsCtx.uid))
}
expressionQuery := func(t *testing.T, tsCtx *testScenarioContext) dtos.MetricRequest {
t.Helper()
return metricRequestWithQueries(t, fmt.Sprintf(`{
"refId": "A",
"datasource": {
"uid": "%s",
"type": "%s"
}
}`, tsCtx.uid, tsCtx.testPluginID), `{
"refId": "B",
"datasource": {
"type": "__expr__",
"uid": "__expr__",
"name": "Expression"
},
"type": "math",
"expression": "$A - 50"
}`)
}
newTestScenario(t, "When oauth token not available", func(t *testing.T, tsCtx *testScenarioContext) {
tsCtx.testEnv.OAuthTokenService.Token = nil
tsCtx.runCheckHealthTest(t)
tsCtx.runCallResourceTest(t)
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
})
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
})
})
newTestScenario(t, "When oauth token available", func(t *testing.T, tsCtx *testScenarioContext) {
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
RefreshToken: "refresh-token",
Expiry: time.Now().UTC().Add(24 * time.Hour),
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tsCtx.testEnv.OAuthTokenService.Token = token
tsCtx.runCheckHealthTest(t)
tsCtx.runCallResourceTest(t)
t.Run("regular query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, regularQuery(t, tsCtx))
})
t.Run("expression query", func(t *testing.T) {
tsCtx.runQueryDataTest(t, expressionQuery(t, tsCtx))
})
})
}
type testScenarioContext struct {
testPluginID string
uid string
grafanaListeningAddr string
testEnv *server.TestEnv
outgoingServer *httptest.Server
outgoingRequest *http.Request
backendTestPlugin *testPlugin
rt http.RoundTripper
}
func newTestScenario(t *testing.T, name string, callback func(t *testing.T, ctx *testScenarioContext)) {
tsCtx := testScenarioContext{
testPluginID: "test-plugin",
}
dir, path := testinfra.CreateGrafDir(t, testinfra.GrafanaOpts{
DisableAnonymous: true,
// EnableLog: true,
})
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
tsCtx.grafanaListeningAddr = grafanaListeningAddr
tsCtx.testEnv = testEnv
ctx := context.Background()
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
})
tsCtx.outgoingServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tsCtx.outgoingRequest = r
w.WriteHeader(http.StatusUnauthorized)
}))
t.Cleanup(tsCtx.outgoingServer.Close)
testPlugin, backendTestPlugin := createTestPlugin(tsCtx.testPluginID)
tsCtx.backendTestPlugin = backendTestPlugin
err := testEnv.PluginRegistry.Add(ctx, testPlugin)
require.NoError(t, err)
jsonData := simplejson.NewFromAny(map[string]interface{}{
"httpHeaderName1": "X-CUSTOM-HEADER",
"oauthPassThru": true,
"keepCookies": []string{"cookie1", "cookie3", "grafana_session"},
})
secureJSONData := map[string]string{
"basicAuthPassword": "basicAuthPassword",
"httpHeaderValue1": "custom-header-value",
}
tsCtx.uid = "test-plugin"
err = testEnv.Server.HTTPServer.DataSourcesService.AddDataSource(ctx, &datasources.AddDataSourceCommand{
OrgId: 1,
Access: datasources.DS_ACCESS_PROXY,
Name: "TestPlugin",
Type: tsCtx.testPluginID,
Uid: tsCtx.uid,
Url: tsCtx.outgoingServer.URL,
BasicAuth: true,
BasicAuthUser: "basicAuthUser",
JsonData: jsonData,
SecureJsonData: secureJSONData,
})
require.NoError(t, err)
getDataSourceQuery := &datasources.GetDataSourceQuery{
OrgId: 1,
Uid: tsCtx.uid,
}
err = testEnv.Server.HTTPServer.DataSourcesService.GetDataSource(ctx, getDataSourceQuery)
require.NoError(t, err)
rt, err := testEnv.Server.HTTPServer.DataSourcesService.GetHTTPTransport(ctx, getDataSourceQuery.Result, testEnv.HTTPClientProvider)
require.NoError(t, err)
tsCtx.rt = rt
t.Run(name, func(t *testing.T) {
callback(t, &tsCtx)
})
}
func (tsCtx *testScenarioContext) runQueryDataTest(t *testing.T, mr dtos.MetricRequest) {
t.Run("When calling /api/ds/query should set expected headers on outgoing QueryData and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.QueryDataRequest
}
tsCtx.backendTestPlugin.QueryDataHandler = backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
received = &struct {
ctx context.Context
req *backend.QueryDataRequest
}{ctx, req}
c := http.Client{
Transport: tsCtx.rt,
}
outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil)
require.NoError(t, err)
resp, err := c.Do(outReq)
if err != nil {
return nil, err
}
defer func() {
if err := resp.Body.Close(); err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err)
}
}()
_, err = io.Copy(io.Discard, resp.Body)
if err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err)
}
return &backend.QueryDataResponse{}, nil
})
buf1 := &bytes.Buffer{}
err := json.NewEncoder(buf1).Encode(mr)
require.NoError(t, err)
u := fmt.Sprintf("http://admin:admin@%s/api/ds/query", tsCtx.grafanaListeningAddr)
req, err := http.NewRequest(http.MethodPost, u, buf1)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode, string(b))
t.Cleanup(func() {
err := resp.Body.Close()
require.NoError(t, err)
})
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
})
}
func (tsCtx *testScenarioContext) runCheckHealthTest(t *testing.T) {
t.Run("When calling /api/datasources/uid/:uid/health should set expected headers on outgoing CheckHealth and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.CheckHealthRequest
}
tsCtx.backendTestPlugin.CheckHealthHandler = backend.CheckHealthHandlerFunc(func(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
received = &struct {
ctx context.Context
req *backend.CheckHealthRequest
}{ctx, req}
c := http.Client{
Transport: tsCtx.rt,
}
outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil)
require.NoError(t, err)
resp, err := c.Do(outReq)
if err != nil {
return nil, err
}
defer func() {
if err := resp.Body.Close(); err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err)
}
}()
_, err = io.Copy(io.Discard, resp.Body)
if err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err)
}
return &backend.CheckHealthResult{
Status: backend.HealthStatusOk,
}, nil
})
u := fmt.Sprintf("http://admin:admin@%s/api/datasources/uid/%s/health", tsCtx.grafanaListeningAddr, tsCtx.uid)
req, err := http.NewRequest(http.MethodGet, u, nil)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode, string(b))
t.Cleanup(func() {
err := resp.Body.Close()
require.NoError(t, err)
})
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"])
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
})
}
func (tsCtx *testScenarioContext) runCallResourceTest(t *testing.T) {
t.Run("When calling /api/datasources/uid/:uid/resources should set expected headers on outgoing CallResource and HTTP request", func(t *testing.T) {
var received *struct {
ctx context.Context
req *backend.CallResourceRequest
}
tsCtx.backendTestPlugin.CallResourceHandler = backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
received = &struct {
ctx context.Context
req *backend.CallResourceRequest
}{ctx, req}
c := http.Client{
Transport: tsCtx.rt,
}
outReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tsCtx.outgoingServer.URL, nil)
require.NoError(t, err)
resp, err := c.Do(outReq)
if err != nil {
return err
}
defer func() {
if err := resp.Body.Close(); err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to close body", "error", err)
}
}()
_, err = io.Copy(io.Discard, resp.Body)
if err != nil {
tsCtx.testEnv.Server.HTTPServer.Cfg.Logger.Error("Failed to discard body", "error", err)
}
err = sender.Send(&backend.CallResourceResponse{
Status: http.StatusOK,
})
return err
})
u := fmt.Sprintf("http://admin:admin@%s/api/datasources/uid/%s/resources", tsCtx.grafanaListeningAddr, tsCtx.uid)
req, err := http.NewRequest(http.MethodGet, u, nil)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "cookie1",
})
req.AddCookie(&http.Cookie{
Name: "cookie2",
})
req.AddCookie(&http.Cookie{
Name: "cookie3",
})
req.AddCookie(&http.Cookie{
Name: "grafana_session",
})
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode, string(b))
t.Cleanup(func() {
err := resp.Body.Close()
require.NoError(t, err)
})
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// backend query data request
require.NotNil(t, received)
require.Equal(t, "cookie1=; cookie3=", received.req.Headers["Cookie"][0])
token := tsCtx.testEnv.OAuthTokenService.Token
var expectedAuthHeader string
var expectedTokenHeader string
if token != nil {
expectedAuthHeader = fmt.Sprintf("Bearer %s", token.AccessToken)
expectedTokenHeader = token.Extra("id_token").(string)
require.Equal(t, expectedAuthHeader, received.req.Headers["Authorization"][0])
require.Equal(t, expectedTokenHeader, received.req.Headers["X-ID-Token"][0])
}
// outgoing HTTP request
require.NotNil(t, tsCtx.outgoingRequest)
require.Equal(t, "cookie1=; cookie3=", tsCtx.outgoingRequest.Header.Get("Cookie"))
require.Equal(t, "custom-header-value", tsCtx.outgoingRequest.Header.Get("X-CUSTOM-HEADER"))
if token == nil {
username, pwd, ok := tsCtx.outgoingRequest.BasicAuth()
require.True(t, ok)
require.Equal(t, "basicAuthUser", username)
require.Equal(t, "basicAuthPassword", pwd)
} else {
require.Equal(t, expectedAuthHeader, tsCtx.outgoingRequest.Header.Get("Authorization"))
require.Equal(t, expectedTokenHeader, tsCtx.outgoingRequest.Header.Get("X-ID-Token"))
}
})
}
func createTestPlugin(id string) (*plugins.Plugin, *testPlugin) {
p := &plugins.Plugin{
JSONData: plugins.JSONData{
ID: id,
},
Class: plugins.Core,
}
p.SetLogger(log.New("test-plugin"))
tp := &testPlugin{
pluginID: id,
logger: p.Logger(),
QueryDataHandler: backend.QueryDataHandlerFunc(func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
return &backend.QueryDataResponse{}, nil
}),
}
p.RegisterClient(tp)
return p, tp
}
type testPlugin struct {
pluginID string
logger log.Logger
backend.CheckHealthHandler
backend.CallResourceHandler
backend.QueryDataHandler
backend.StreamHandler
}
func (tp *testPlugin) PluginID() string {
return tp.pluginID
}
func (tp *testPlugin) Logger() log.Logger {
return tp.logger
}
func (tp *testPlugin) Start(ctx context.Context) error {
return nil
}
func (tp *testPlugin) Stop(ctx context.Context) error {
return nil
}
func (tp *testPlugin) IsManaged() bool {
return true
}
func (tp *testPlugin) Exited() bool {
return false
}
func (tp *testPlugin) Decommission() error {
return nil
}
func (tp *testPlugin) IsDecommissioned() bool {
return false
}
func (tp *testPlugin) CollectMetrics(_ context.Context, _ *backend.CollectMetricsRequest) (*backend.CollectMetricsResult, error) {
return nil, backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
if tp.CheckHealthHandler != nil {
return tp.CheckHealthHandler.CheckHealth(ctx, req)
}
return nil, backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
if tp.QueryDataHandler != nil {
return tp.QueryDataHandler.QueryData(ctx, req)
}
return nil, backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
if tp.CallResourceHandler != nil {
return tp.CallResourceHandler.CallResource(ctx, req, sender)
}
return backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) {
if tp.StreamHandler != nil {
return tp.StreamHandler.SubscribeStream(ctx, req)
}
return nil, backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) {
if tp.StreamHandler != nil {
return tp.StreamHandler.PublishStream(ctx, req)
}
return nil, backendplugin.ErrMethodNotImplemented
}
func (tp *testPlugin) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error {
if tp.StreamHandler != nil {
return tp.StreamHandler.RunStream(ctx, req, sender)
}
return backendplugin.ErrMethodNotImplemented
}
func metricRequestWithQueries(t *testing.T, rawQueries ...string) dtos.MetricRequest {
t.Helper()
queries := make([]*simplejson.Json, 0)
for _, q := range rawQueries {
json, err := simplejson.NewJson([]byte(q))
require.NoError(t, err)
queries = append(queries, json)
}
return dtos.MetricRequest{
From: "now-1h",
To: "now",
Queries: queries,
Debug: false,
}
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/tests/testinfra"
"github.com/stretchr/testify/require"
@@ -31,7 +30,7 @@ func TestIntegrationPrometheusBuffered(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -116,7 +115,7 @@ func TestIntegrationPrometheusClient(t *testing.T) {
grafanaListeningAddr, testEnv := testinfra.StartGrafanaEnv(t, dir, path)
ctx := context.Background()
createUser(t, testEnv.SQLStore, user.CreateUserCommand{
testinfra.CreateUser(t, testEnv.SQLStore, user.CreateUserCommand{
DefaultOrgRole: string(org.RoleAdmin),
Password: "admin",
Login: "admin",
@@ -213,14 +212,3 @@ func TestIntegrationPrometheusClient(t *testing.T) {
require.Equal(t, "basicAuthPassword", pwd)
})
}
func createUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/grafana/grafana/pkg/server"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
@@ -307,7 +308,13 @@ func CreateGrafDir(t *testing.T, opts ...GrafanaOpts) (string, string) {
require.NoError(t, err)
_, err = logSection.NewKey("enabled", "false")
require.NoError(t, err)
} else {
serverSection, err := getOrCreateSection("server")
require.NoError(t, err)
_, err = serverSection.NewKey("router_logging", "true")
require.NoError(t, err)
}
if o.GRPCServerAddress != "" {
logSection, err := getOrCreateSection("grpc_server")
require.NoError(t, err)
@@ -356,3 +363,14 @@ type GrafanaOpts struct {
GRPCServerAddress string
QueryRetries int64
}
func CreateUser(t *testing.T, store *sqlstore.SQLStore, cmd user.CreateUserCommand) int64 {
t.Helper()
store.Cfg.AutoAssignOrg = true
store.Cfg.AutoAssignOrgId = 1
u, err := store.CreateUser(context.Background(), cmd)
require.NoError(t, err)
return u.ID
}

View File

@@ -2,7 +2,6 @@ package service
import (
"context"
"fmt"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend"
@@ -14,10 +13,6 @@ import (
"github.com/grafana/grafana/pkg/tsdb/legacydata"
)
var oAuthIsOAuthPassThruEnabledFunc = func(oAuthTokenService oauthtoken.OAuthTokenService, ds *datasources.DataSource) bool {
return oAuthTokenService.IsOAuthPassThruEnabled(ds)
}
type Service struct {
pluginsClient plugins.Client
oAuthTokenService oauthtoken.OAuthTokenService
@@ -45,13 +40,6 @@ func (h *Service) HandleRequest(ctx context.Context, ds *datasources.DataSource,
return legacydata.DataResponse{}, err
}
// Attach Auth information
if oAuthIsOAuthPassThruEnabledFunc(h.oAuthTokenService, ds) {
if token := h.oAuthTokenService.GetCurrentOAuthToken(ctx, query.User); token != nil {
query.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
}
}
resp, err := h.pluginsClient.QueryData(ctx, req)
if err != nil {
return legacydata.DataResponse{}, err

View File

@@ -15,7 +15,6 @@ import (
"github.com/grafana/grafana/pkg/services/datasources"
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/quota/quotatest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
secretskvs "github.com/grafana/grafana/pkg/services/secrets/kvstore"
@@ -25,15 +24,6 @@ import (
func TestHandleRequest(t *testing.T) {
t.Run("Should invoke plugin manager QueryData when handling request for query", func(t *testing.T) {
origOAuthIsOAuthPassThruEnabledFunc := oAuthIsOAuthPassThruEnabledFunc
oAuthIsOAuthPassThruEnabledFunc = func(oAuthTokenService oauthtoken.OAuthTokenService, ds *datasources.DataSource) bool {
return false
}
t.Cleanup(func() {
oAuthIsOAuthPassThruEnabledFunc = origOAuthIsOAuthPassThruEnabledFunc
})
client := &fakePluginsClient{}
var actualReq *backend.QueryDataRequest
client.QueryDataHandlerFunc = func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {