From 5623b5afafbf541eac66fab1d9398756bdf9052c Mon Sep 17 00:00:00 2001 From: Kyle Brandt Date: Mon, 28 Nov 2022 07:40:06 -0500 Subject: [PATCH] SSE: Make sure to forward headers, user and cookies/OAuth token (#58897) Fixes #58793 and Fixes https://github.com/grafana/azure-data-explorer-datasource/issues/513 Co-authored-by: Marcus Efraimsson --- pkg/expr/graph.go | 11 +- pkg/expr/nodes.go | 68 +++--- pkg/expr/transform.go | 5 + .../publicdashboards/api/common_test.go | 2 +- pkg/services/query/query.go | 93 +------- pkg/services/query/query_parsing.go | 157 +++++++++++++ pkg/services/query/query_test.go | 212 +++++++++++++++++- 7 files changed, 426 insertions(+), 122 deletions(-) create mode 100644 pkg/services/query/query_parsing.go diff --git a/pkg/expr/graph.go b/pkg/expr/graph.go index d473af64566..deb915e9f81 100644 --- a/pkg/expr/graph.go +++ b/pkg/expr/graph.go @@ -144,11 +144,12 @@ func (s *Service) buildGraph(req *Request) (*simple.DirectedGraph, error) { } rn := &rawNode{ - Query: rawQueryProp, - RefID: query.RefID, - TimeRange: query.TimeRange, - QueryType: query.QueryType, - DataSource: query.DataSource, + Query: rawQueryProp, + RefID: query.RefID, + TimeRange: query.TimeRange, + QueryType: query.QueryType, + DataSource: query.DataSource, + QueryEnricher: query.QueryEnricher, } var node Node diff --git a/pkg/expr/nodes.go b/pkg/expr/nodes.go index cd6fca63b8b..f679ec38354 100644 --- a/pkg/expr/nodes.go +++ b/pkg/expr/nodes.go @@ -42,11 +42,12 @@ type baseNode struct { } type rawNode struct { - RefID string `json:"refId"` - Query map[string]interface{} - QueryType string - TimeRange TimeRange - DataSource *datasources.DataSource + RefID string `json:"refId"` + Query map[string]interface{} + QueryType string + TimeRange TimeRange + DataSource *datasources.DataSource + QueryEnricher QueryDataRequestEnricher } func (rn *rawNode) GetCommandType() (c CommandType, err error) { @@ -139,8 +140,9 @@ const ( // DSNode is a DPNode that holds a datasource request. type DSNode struct { baseNode - query json.RawMessage - datasource *datasources.DataSource + query json.RawMessage + datasource *datasources.DataSource + queryEnricher QueryDataRequestEnricher orgID int64 queryType string @@ -169,14 +171,15 @@ func (s *Service) buildDSNode(dp *simple.DirectedGraph, rn *rawNode, req *Reques id: dp.NewNode().ID(), refID: rn.RefID, }, - orgID: req.OrgId, - query: json.RawMessage(encodedQuery), - queryType: rn.QueryType, - intervalMS: defaultIntervalMS, - maxDP: defaultMaxDP, - timeRange: rn.TimeRange, - request: *req, - datasource: rn.DataSource, + orgID: req.OrgId, + query: json.RawMessage(encodedQuery), + queryType: rn.QueryType, + intervalMS: defaultIntervalMS, + maxDP: defaultMaxDP, + timeRange: rn.TimeRange, + request: *req, + datasource: rn.DataSource, + queryEnricher: rn.QueryEnricher, } var floatIntervalMS float64 @@ -211,24 +214,29 @@ func (dn *DSNode) Execute(ctx context.Context, now time.Time, _ mathexp.Vars, s OrgID: dn.orgID, DataSourceInstanceSettings: dsInstanceSettings, PluginID: dn.datasource.Type, + User: dn.request.User, } - q := []backend.DataQuery{ - { - RefID: dn.refID, - MaxDataPoints: dn.maxDP, - Interval: time.Duration(int64(time.Millisecond) * dn.intervalMS), - JSON: dn.query, - TimeRange: dn.timeRange.AbsoluteTime(now), - QueryType: dn.queryType, - }, - } - - resp, err := s.dataService.QueryData(ctx, &backend.QueryDataRequest{ + req := &backend.QueryDataRequest{ PluginContext: pc, - Queries: q, - Headers: dn.request.Headers, - }) + Queries: []backend.DataQuery{ + { + RefID: dn.refID, + MaxDataPoints: dn.maxDP, + Interval: time.Duration(int64(time.Millisecond) * dn.intervalMS), + JSON: dn.query, + TimeRange: dn.timeRange.AbsoluteTime(now), + QueryType: dn.queryType, + }, + }, + Headers: dn.request.Headers, + } + + if dn.queryEnricher != nil { + ctx = dn.queryEnricher(ctx, req) + } + + resp, err := s.dataService.QueryData(ctx, req) if err != nil { return mathexp.Results{}, err } diff --git a/pkg/expr/transform.go b/pkg/expr/transform.go index a128ab9ec88..5d5351b04ef 100644 --- a/pkg/expr/transform.go +++ b/pkg/expr/transform.go @@ -35,14 +35,19 @@ type Request struct { Debug bool OrgId int64 Queries []Query + User *backend.User } +// QueryDataRequestEnricher function definition for enriching a backend.QueryDataRequest request. +type QueryDataRequestEnricher func(ctx context.Context, req *backend.QueryDataRequest) context.Context + // Query is like plugins.DataSubQuery, but with a a time range, and only the UID // for the data source. Also interval is a time.Duration. type Query struct { RefID string TimeRange TimeRange DataSource *datasources.DataSource `json:"datasource"` + QueryEnricher QueryDataRequestEnricher JSON json.RawMessage Interval time.Duration QueryType string diff --git a/pkg/services/publicdashboards/api/common_test.go b/pkg/services/publicdashboards/api/common_test.go index 2af9bf90ac3..116c952db6b 100644 --- a/pkg/services/publicdashboards/api/common_test.go +++ b/pkg/services/publicdashboards/api/common_test.go @@ -131,7 +131,7 @@ func buildQueryDataService(t *testing.T, cs datasources.CacheService, fpc *fakeP } return query.ProvideService( - nil, + setting.NewCfg(), cs, nil, &fakePluginRequestValidator{}, diff --git a/pkg/services/query/query.go b/pkg/services/query/query.go index cefd648e19e..c9c3eca36c0 100644 --- a/pkg/services/query/query.go +++ b/pkg/services/query/query.go @@ -3,8 +3,6 @@ package query import ( "context" "fmt" - "net/http" - "strings" "time" "github.com/grafana/grafana/pkg/api/dtos" @@ -133,10 +131,17 @@ func (s *Service) QueryData(ctx context.Context, user *user.SignedInUser, skipCa // handleExpressions handles POST /api/ds/query when there is an expression. func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser, parsedReq *parsedRequest) (*backend.QueryDataResponse, error) { exprReq := expr.Request{ - OrgId: user.OrgID, Queries: []expr.Query{}, } + if user != nil { // for passthrough authentication, SSE does not authenticate + exprReq.User = adapters.BackendUserFromSignedInUser(user) + exprReq.OrgId = user.OrgID + } + + disallowedCookies := []string{s.cfg.LoginCookieName} + queryEnrichers := parsedReq.createDataSourceQueryEnrichers(ctx, user, s.oAuthTokenService, disallowedCookies) + for _, pq := range parsedReq.getFlattenedQueries() { if pq.datasource == nil { return nil, ErrMissingDataSourceInfo.Build(errutil.TemplateData{ @@ -157,6 +162,7 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser From: pq.query.TimeRange.From, To: pq.query.TimeRange.To, }, + QueryEnricher: queryEnrichers[pq.datasource.Uid], }) } @@ -198,10 +204,11 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si Queries: []backend.DataQuery{}, } + disallowedCookies := []string{s.cfg.LoginCookieName} middlewares := []httpclient.Middleware{} if parsedReq.httpRequest != nil { middlewares = append(middlewares, - httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), []string{s.cfg.LoginCookieName}), + httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), disallowedCookies), ) } @@ -218,7 +225,7 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si } if parsedReq.httpRequest != nil { - proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), []string{s.cfg.LoginCookieName}) + proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), disallowedCookies) if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" { req.Headers["Cookie"] = cookieStr } @@ -233,82 +240,6 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si return s.pluginClient.QueryData(ctx, req) } -type parsedQuery struct { - datasource *datasources.DataSource - query backend.DataQuery - rawQuery *simplejson.Json -} - -type parsedRequest struct { - hasExpression bool - parsedQueries map[string][]parsedQuery - dsTypes map[string]bool - httpRequest *http.Request -} - -func (pr parsedRequest) getFlattenedQueries() []parsedQuery { - queries := make([]parsedQuery, 0) - for _, pq := range pr.parsedQueries { - queries = append(queries, pq...) - } - return queries -} - -func (pr parsedRequest) validateRequest() error { - if pr.httpRequest == nil { - return nil - } - - if pr.hasExpression { - hasExpr := pr.httpRequest.URL.Query().Get("expression") - if hasExpr == "" || hasExpr == "true" { - return nil - } - return ErrQueryParamMismatch - } - - vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID)) - count := len(vals) - if count > 0 { // header exists - if count != len(pr.parsedQueries) { - return ErrQueryParamMismatch - } - for _, t := range vals { - if pr.parsedQueries[t] == nil { - return ErrQueryParamMismatch - } - } - } - - vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID)) - count = len(vals) - if count > 0 { // header exists - if count != len(pr.dsTypes) { - return ErrQueryParamMismatch - } - for _, t := range vals { - if !pr.dsTypes[t] { - return ErrQueryParamMismatch - } - } - } - return nil -} - -func splitHeaders(headers []string) []string { - out := []string{} - for _, v := range headers { - if strings.Contains(v, ",") { - for _, sub := range strings.Split(v, ",") { - out = append(out, strings.TrimSpace(sub)) - } - } else { - out = append(out, v) - } - } - return out -} - // parseRequest parses a request into parsed queries grouped by datasource uid func (s *Service) parseMetricRequest(ctx context.Context, user *user.SignedInUser, skipCache bool, reqDTO dtos.MetricRequest) (*parsedRequest, error) { if len(reqDTO.Queries) == 0 { diff --git a/pkg/services/query/query_parsing.go b/pkg/services/query/query_parsing.go new file mode 100644 index 00000000000..9021d6a658c --- /dev/null +++ b/pkg/services/query/query_parsing.go @@ -0,0 +1,157 @@ +package query + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/expr" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" + "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/oauthtoken" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/util/proxyutil" + "golang.org/x/oauth2" +) + +type parsedQuery struct { + datasource *datasources.DataSource + query backend.DataQuery + rawQuery *simplejson.Json +} + +type parsedRequest struct { + hasExpression bool + parsedQueries map[string][]parsedQuery + dsTypes map[string]bool + httpRequest *http.Request +} + +func (pr parsedRequest) getFlattenedQueries() []parsedQuery { + queries := make([]parsedQuery, 0) + for _, pq := range pr.parsedQueries { + queries = append(queries, pq...) + } + return queries +} + +func (pr parsedRequest) validateRequest() error { + if pr.httpRequest == nil { + return nil + } + + if pr.hasExpression { + hasExpr := pr.httpRequest.URL.Query().Get("expression") + if hasExpr == "" || hasExpr == "true" { + return nil + } + return ErrQueryParamMismatch + } + + vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID)) + count := len(vals) + if count > 0 { // header exists + if count != len(pr.parsedQueries) { + return ErrQueryParamMismatch + } + for _, t := range vals { + if pr.parsedQueries[t] == nil { + return ErrQueryParamMismatch + } + } + } + + vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID)) + count = len(vals) + if count > 0 { // header exists + if count != len(pr.dsTypes) { + return ErrQueryParamMismatch + } + for _, t := range vals { + if !pr.dsTypes[t] { + return ErrQueryParamMismatch + } + } + } + return nil +} + +func (pr parsedRequest) createDataSourceQueryEnrichers(ctx context.Context, signedInUser *user.SignedInUser, oAuthTokenService oauthtoken.OAuthTokenService, disallowedCookies []string) map[string]expr.QueryDataRequestEnricher { + datasourcesHeaderProvider := map[string]expr.QueryDataRequestEnricher{} + + if pr.httpRequest == nil { + return datasourcesHeaderProvider + } + + for uid, queries := range pr.parsedQueries { + if expr.IsDataSource(uid) { + continue + } + + if len(queries) == 0 || queries[0].datasource == nil { + continue + } + + if _, exists := datasourcesHeaderProvider[uid]; exists { + continue + } + + ds := queries[0].datasource + allowedCookies := ds.AllowedCookies() + clonedReq := pr.httpRequest.Clone(pr.httpRequest.Context()) + + var token *oauth2.Token + + if oAuthTokenService.IsOAuthPassThruEnabled(ds) { + token = oAuthTokenService.GetCurrentOAuthToken(ctx, signedInUser) + } + + datasourcesHeaderProvider[uid] = func(ctx context.Context, req *backend.QueryDataRequest) context.Context { + if len(req.Headers) == 0 { + req.Headers = map[string]string{} + } + + if len(allowedCookies) > 0 { + proxyutil.ClearCookieHeader(clonedReq, allowedCookies, disallowedCookies) + if cookieStr := clonedReq.Header.Get("Cookie"); cookieStr != "" { + req.Headers["Cookie"] = cookieStr + } + + ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(clonedReq.Cookies(), allowedCookies, disallowedCookies)) + } + + if token != nil { + req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken) + + idToken, ok := token.Extra("id_token").(string) + if ok && idToken != "" { + req.Headers["X-ID-Token"] = idToken + } + + ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedOAuthIdentityMiddleware(token)) + } + + return ctx + } + } + + return datasourcesHeaderProvider +} + +func splitHeaders(headers []string) []string { + out := []string{} + for _, v := range headers { + if strings.Contains(v, ",") { + for _, sub := range strings.Split(v, ",") { + out = append(out, strings.TrimSpace(sub)) + } + } else { + out = append(out, v) + } + } + return out +} diff --git a/pkg/services/query/query_test.go b/pkg/services/query/query_test.go index af81ea8f8b4..1d22e39ece4 100644 --- a/pkg/services/query/query_test.go +++ b/pkg/services/query/query_test.go @@ -5,9 +5,11 @@ import ( "context" "errors" "net/http" + "net/http/httptest" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -16,6 +18,7 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/expr" "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" @@ -33,9 +36,32 @@ import ( ) func TestParseMetricRequest(t *testing.T) { - tc := setup(t) - t.Run("Test a simple single datasource query", func(t *testing.T) { + tc := setup(t) + json, err := simplejson.NewJson([]byte(`{ + "keepCookies": [ "cookie1", "cookie3", "login" ] + }`)) + require.NoError(t, err) + tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { + if datasourceUID == "gIEkMvIVz" { + return &datasources.DataSource{ + Uid: "gIEkMvIVz", + JsonData: json, + }, nil + } + + return nil, nil + } + + token := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + } + token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) + + tc.oauthTokenService.passThruEnabled = true + tc.oauthTokenService.token = token + mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -56,9 +82,61 @@ func TestParseMetricRequest(t *testing.T) { assert.Len(t, parsedReq.parsedQueries, 1) assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") assert.Len(t, parsedReq.getFlattenedQueries(), 2) + + t.Run("createDataSourceQueryEnrichers should return 0 enrichers when no HTTP request", func(t *testing.T) { + enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{}) + require.Empty(t, enrichers) + }) + + t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) { + parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) + + enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) + require.Len(t, enrichers, 1) + require.NotNil(t, enrichers["gIEkMvIVz"]) + req := &backend.QueryDataRequest{} + ctx := enrichers["gIEkMvIVz"](context.Background(), req) + require.Len(t, req.Headers, 3) + require.Equal(t, "Bearer access-token", req.Headers["Authorization"]) + require.Equal(t, "id-token", req.Headers["X-ID-Token"]) + require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"]) + middlewares := httpclient.ContextualMiddlewareFromContext(ctx) + require.Len(t, middlewares, 2) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName()) + }) }) t.Run("Test a single datasource query with expressions", func(t *testing.T) { + tc := setup(t) + json, err := simplejson.NewJson([]byte(`{ + "keepCookies": [ "cookie1", "cookie3", "login" ] + }`)) + require.NoError(t, err) + tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { + if datasourceUID == "gIEkMvIVz" { + return &datasources.DataSource{ + Uid: "gIEkMvIVz", + JsonData: json, + }, nil + } + + return nil, nil + } + + token := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + } + token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) + + tc.oauthTokenService.passThruEnabled = true + tc.oauthTokenService.token = token + mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -85,9 +163,68 @@ func TestParseMetricRequest(t *testing.T) { // Make sure we end up with something valid _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) assert.NoError(t, err) + + t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) { + parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) + + enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) + require.Len(t, enrichers, 1) + require.NotNil(t, enrichers["gIEkMvIVz"]) + + req := &backend.QueryDataRequest{} + ctx := enrichers["gIEkMvIVz"](context.Background(), req) + require.Len(t, req.Headers, 3) + require.Equal(t, "Bearer access-token", req.Headers["Authorization"]) + require.Equal(t, "id-token", req.Headers["X-ID-Token"]) + require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"]) + middlewares := httpclient.ContextualMiddlewareFromContext(ctx) + require.Len(t, middlewares, 2) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName()) + }) }) t.Run("Test a simple mixed datasource query", func(t *testing.T) { + tc := setup(t) + json, err := simplejson.NewJson([]byte(`{ + "keepCookies": [ "cookie1", "cookie3", "login" ] + }`)) + require.NoError(t, err) + json2, err := simplejson.NewJson([]byte(`{ + "keepCookies": [ "cookie2" ] + }`)) + require.NoError(t, err) + tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { + if datasourceUID == "gIEkMvIVz" { + return &datasources.DataSource{ + Uid: "gIEkMvIVz", + JsonData: json, + }, nil + } + + if datasourceUID == "sEx6ZvSVk" { + return &datasources.DataSource{ + Uid: "sEx6ZvSVk", + JsonData: json2, + }, nil + } + + return nil, nil + } + + token := &oauth2.Token{ + TokenType: "bearer", + AccessToken: "access-token", + } + token = token.WithExtra(map[string]interface{}{"id_token": "id-token"}) + + tc.oauthTokenService.passThruEnabled = true + tc.oauthTokenService.token = token + mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -100,6 +237,12 @@ func TestParseMetricRequest(t *testing.T) { "uid": "sEx6ZvSVk", "type": "testdata" } + }`, `{ + "refId": "C", + "datasource": { + "uid": "sEx6ZvSVk", + "type": "testdata" + } }`) parsedReq, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) require.NoError(t, err) @@ -107,11 +250,51 @@ func TestParseMetricRequest(t *testing.T) { assert.False(t, parsedReq.hasExpression) assert.Len(t, parsedReq.parsedQueries, 2) assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") + assert.Len(t, parsedReq.parsedQueries["gIEkMvIVz"], 1) assert.Contains(t, parsedReq.parsedQueries, "sEx6ZvSVk") - assert.Len(t, parsedReq.getFlattenedQueries(), 2) + assert.Len(t, parsedReq.parsedQueries["sEx6ZvSVk"], 2) + assert.Len(t, parsedReq.getFlattenedQueries(), 3) + + t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) { + parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"}) + parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"}) + + enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"}) + require.Len(t, enrichers, 2) + + enricherOne := enrichers["gIEkMvIVz"] + require.NotNil(t, enricherOne) + reqOne := &backend.QueryDataRequest{} + ctx := enricherOne(context.Background(), reqOne) + require.Len(t, reqOne.Headers, 3) + require.Equal(t, "Bearer access-token", reqOne.Headers["Authorization"]) + require.Equal(t, "id-token", reqOne.Headers["X-ID-Token"]) + require.Equal(t, "cookie1=; cookie3=", reqOne.Headers["Cookie"]) + middlewaresOne := httpclient.ContextualMiddlewareFromContext(ctx) + require.Len(t, middlewaresOne, 2) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresOne[0].(httpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresOne[1].(httpclient.MiddlewareName).MiddlewareName()) + + enricherTwo := enrichers["sEx6ZvSVk"] + require.NotNil(t, enricherTwo) + reqTwo := &backend.QueryDataRequest{} + ctx = enricherTwo(context.Background(), reqTwo) + require.Len(t, reqTwo.Headers, 3) + require.Equal(t, "Bearer access-token", reqTwo.Headers["Authorization"]) + require.Equal(t, "id-token", reqTwo.Headers["X-ID-Token"]) + require.Equal(t, "cookie2=", reqTwo.Headers["Cookie"]) + middlewaresTwo := httpclient.ContextualMiddlewareFromContext(ctx) + require.Len(t, middlewaresTwo, 2) + require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresTwo[0].(httpclient.MiddlewareName).MiddlewareName()) + require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresTwo[1].(httpclient.MiddlewareName).MiddlewareName()) + }) }) t.Run("Test a mixed datasource query with expressions", func(t *testing.T) { + tc := setup(t) mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -169,9 +352,18 @@ func TestParseMetricRequest(t *testing.T) { // Make sure we end up with something valid _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) assert.NoError(t, err) + + t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) { + parsedReq.httpRequest = &http.Request{} + enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{}) + require.Len(t, enrichers, 2) + require.NotNil(t, enrichers["gIEkMvIVz"]) + require.NotNil(t, enrichers["sEx6ZvSVk"]) + }) }) t.Run("Header validation", func(t *testing.T) { + tc := setup(t) mr := metricRequestWithQueries(t, `{ "refId": "A", "datasource": { @@ -351,7 +543,12 @@ func TestQueryData(t *testing.T) { tc.oauthTokenService.passThruEnabled = true tc.oauthTokenService.token = token - _, err := tc.queryService.QueryData(context.Background(), nil, true, metricRequest()) + metricReq := metricRequest() + httpReq, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + metricReq.HTTPRequest = httpReq + + _, err = tc.queryService.QueryData(context.Background(), nil, true, metricReq) require.Nil(t, err) expected := map[string]string{ @@ -523,7 +720,8 @@ func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models. } type fakeDataSourceCache struct { - ds *datasources.DataSource + ds *datasources.DataSource + dsByUid func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) } func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { @@ -531,6 +729,10 @@ func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID in } func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { + if c.dsByUid != nil { + return c.dsByUid(ctx, datasourceUID, user, skipCache) + } + return &datasources.DataSource{ Uid: datasourceUID, }, nil