SSE: Make sure to forward headers, user and cookies/OAuth token (#58897)

Fixes #58793 and Fixes https://github.com/grafana/azure-data-explorer-datasource/issues/513

Co-authored-by: Marcus Efraimsson <marcus.efraimsson@gmail.com>
This commit is contained in:
Kyle Brandt
2022-11-28 07:40:06 -05:00
committed by GitHub
parent ba37516690
commit 5623b5afaf
7 changed files with 426 additions and 122 deletions

View File

@@ -144,11 +144,12 @@ func (s *Service) buildGraph(req *Request) (*simple.DirectedGraph, error) {
} }
rn := &rawNode{ rn := &rawNode{
Query: rawQueryProp, Query: rawQueryProp,
RefID: query.RefID, RefID: query.RefID,
TimeRange: query.TimeRange, TimeRange: query.TimeRange,
QueryType: query.QueryType, QueryType: query.QueryType,
DataSource: query.DataSource, DataSource: query.DataSource,
QueryEnricher: query.QueryEnricher,
} }
var node Node var node Node

View File

@@ -42,11 +42,12 @@ type baseNode struct {
} }
type rawNode struct { type rawNode struct {
RefID string `json:"refId"` RefID string `json:"refId"`
Query map[string]interface{} Query map[string]interface{}
QueryType string QueryType string
TimeRange TimeRange TimeRange TimeRange
DataSource *datasources.DataSource DataSource *datasources.DataSource
QueryEnricher QueryDataRequestEnricher
} }
func (rn *rawNode) GetCommandType() (c CommandType, err error) { func (rn *rawNode) GetCommandType() (c CommandType, err error) {
@@ -139,8 +140,9 @@ const (
// DSNode is a DPNode that holds a datasource request. // DSNode is a DPNode that holds a datasource request.
type DSNode struct { type DSNode struct {
baseNode baseNode
query json.RawMessage query json.RawMessage
datasource *datasources.DataSource datasource *datasources.DataSource
queryEnricher QueryDataRequestEnricher
orgID int64 orgID int64
queryType string queryType string
@@ -169,14 +171,15 @@ func (s *Service) buildDSNode(dp *simple.DirectedGraph, rn *rawNode, req *Reques
id: dp.NewNode().ID(), id: dp.NewNode().ID(),
refID: rn.RefID, refID: rn.RefID,
}, },
orgID: req.OrgId, orgID: req.OrgId,
query: json.RawMessage(encodedQuery), query: json.RawMessage(encodedQuery),
queryType: rn.QueryType, queryType: rn.QueryType,
intervalMS: defaultIntervalMS, intervalMS: defaultIntervalMS,
maxDP: defaultMaxDP, maxDP: defaultMaxDP,
timeRange: rn.TimeRange, timeRange: rn.TimeRange,
request: *req, request: *req,
datasource: rn.DataSource, datasource: rn.DataSource,
queryEnricher: rn.QueryEnricher,
} }
var floatIntervalMS float64 var floatIntervalMS float64
@@ -211,24 +214,29 @@ func (dn *DSNode) Execute(ctx context.Context, now time.Time, _ mathexp.Vars, s
OrgID: dn.orgID, OrgID: dn.orgID,
DataSourceInstanceSettings: dsInstanceSettings, DataSourceInstanceSettings: dsInstanceSettings,
PluginID: dn.datasource.Type, PluginID: dn.datasource.Type,
User: dn.request.User,
} }
q := []backend.DataQuery{ req := &backend.QueryDataRequest{
{
RefID: dn.refID,
MaxDataPoints: dn.maxDP,
Interval: time.Duration(int64(time.Millisecond) * dn.intervalMS),
JSON: dn.query,
TimeRange: dn.timeRange.AbsoluteTime(now),
QueryType: dn.queryType,
},
}
resp, err := s.dataService.QueryData(ctx, &backend.QueryDataRequest{
PluginContext: pc, PluginContext: pc,
Queries: q, Queries: []backend.DataQuery{
Headers: dn.request.Headers, {
}) RefID: dn.refID,
MaxDataPoints: dn.maxDP,
Interval: time.Duration(int64(time.Millisecond) * dn.intervalMS),
JSON: dn.query,
TimeRange: dn.timeRange.AbsoluteTime(now),
QueryType: dn.queryType,
},
},
Headers: dn.request.Headers,
}
if dn.queryEnricher != nil {
ctx = dn.queryEnricher(ctx, req)
}
resp, err := s.dataService.QueryData(ctx, req)
if err != nil { if err != nil {
return mathexp.Results{}, err return mathexp.Results{}, err
} }

View File

@@ -35,14 +35,19 @@ type Request struct {
Debug bool Debug bool
OrgId int64 OrgId int64
Queries []Query Queries []Query
User *backend.User
} }
// QueryDataRequestEnricher function definition for enriching a backend.QueryDataRequest request.
type QueryDataRequestEnricher func(ctx context.Context, req *backend.QueryDataRequest) context.Context
// Query is like plugins.DataSubQuery, but with a a time range, and only the UID // Query is like plugins.DataSubQuery, but with a a time range, and only the UID
// for the data source. Also interval is a time.Duration. // for the data source. Also interval is a time.Duration.
type Query struct { type Query struct {
RefID string RefID string
TimeRange TimeRange TimeRange TimeRange
DataSource *datasources.DataSource `json:"datasource"` DataSource *datasources.DataSource `json:"datasource"`
QueryEnricher QueryDataRequestEnricher
JSON json.RawMessage JSON json.RawMessage
Interval time.Duration Interval time.Duration
QueryType string QueryType string

View File

@@ -131,7 +131,7 @@ func buildQueryDataService(t *testing.T, cs datasources.CacheService, fpc *fakeP
} }
return query.ProvideService( return query.ProvideService(
nil, setting.NewCfg(),
cs, cs,
nil, nil,
&fakePluginRequestValidator{}, &fakePluginRequestValidator{},

View File

@@ -3,8 +3,6 @@ package query
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"strings"
"time" "time"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@@ -133,10 +131,17 @@ func (s *Service) QueryData(ctx context.Context, user *user.SignedInUser, skipCa
// handleExpressions handles POST /api/ds/query when there is an expression. // handleExpressions handles POST /api/ds/query when there is an expression.
func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser, parsedReq *parsedRequest) (*backend.QueryDataResponse, error) { func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser, parsedReq *parsedRequest) (*backend.QueryDataResponse, error) {
exprReq := expr.Request{ exprReq := expr.Request{
OrgId: user.OrgID,
Queries: []expr.Query{}, Queries: []expr.Query{},
} }
if user != nil { // for passthrough authentication, SSE does not authenticate
exprReq.User = adapters.BackendUserFromSignedInUser(user)
exprReq.OrgId = user.OrgID
}
disallowedCookies := []string{s.cfg.LoginCookieName}
queryEnrichers := parsedReq.createDataSourceQueryEnrichers(ctx, user, s.oAuthTokenService, disallowedCookies)
for _, pq := range parsedReq.getFlattenedQueries() { for _, pq := range parsedReq.getFlattenedQueries() {
if pq.datasource == nil { if pq.datasource == nil {
return nil, ErrMissingDataSourceInfo.Build(errutil.TemplateData{ return nil, ErrMissingDataSourceInfo.Build(errutil.TemplateData{
@@ -157,6 +162,7 @@ func (s *Service) handleExpressions(ctx context.Context, user *user.SignedInUser
From: pq.query.TimeRange.From, From: pq.query.TimeRange.From,
To: pq.query.TimeRange.To, To: pq.query.TimeRange.To,
}, },
QueryEnricher: queryEnrichers[pq.datasource.Uid],
}) })
} }
@@ -198,10 +204,11 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si
Queries: []backend.DataQuery{}, Queries: []backend.DataQuery{},
} }
disallowedCookies := []string{s.cfg.LoginCookieName}
middlewares := []httpclient.Middleware{} middlewares := []httpclient.Middleware{}
if parsedReq.httpRequest != nil { if parsedReq.httpRequest != nil {
middlewares = append(middlewares, middlewares = append(middlewares,
httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), []string{s.cfg.LoginCookieName}), httpclientprovider.ForwardedCookiesMiddleware(parsedReq.httpRequest.Cookies(), ds.AllowedCookies(), disallowedCookies),
) )
} }
@@ -218,7 +225,7 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si
} }
if parsedReq.httpRequest != nil { if parsedReq.httpRequest != nil {
proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), []string{s.cfg.LoginCookieName}) proxyutil.ClearCookieHeader(parsedReq.httpRequest, ds.AllowedCookies(), disallowedCookies)
if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" { if cookieStr := parsedReq.httpRequest.Header.Get("Cookie"); cookieStr != "" {
req.Headers["Cookie"] = cookieStr req.Headers["Cookie"] = cookieStr
} }
@@ -233,82 +240,6 @@ func (s *Service) handleQuerySingleDatasource(ctx context.Context, user *user.Si
return s.pluginClient.QueryData(ctx, req) return s.pluginClient.QueryData(ctx, req)
} }
type parsedQuery struct {
datasource *datasources.DataSource
query backend.DataQuery
rawQuery *simplejson.Json
}
type parsedRequest struct {
hasExpression bool
parsedQueries map[string][]parsedQuery
dsTypes map[string]bool
httpRequest *http.Request
}
func (pr parsedRequest) getFlattenedQueries() []parsedQuery {
queries := make([]parsedQuery, 0)
for _, pq := range pr.parsedQueries {
queries = append(queries, pq...)
}
return queries
}
func (pr parsedRequest) validateRequest() error {
if pr.httpRequest == nil {
return nil
}
if pr.hasExpression {
hasExpr := pr.httpRequest.URL.Query().Get("expression")
if hasExpr == "" || hasExpr == "true" {
return nil
}
return ErrQueryParamMismatch
}
vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID))
count := len(vals)
if count > 0 { // header exists
if count != len(pr.parsedQueries) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if pr.parsedQueries[t] == nil {
return ErrQueryParamMismatch
}
}
}
vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID))
count = len(vals)
if count > 0 { // header exists
if count != len(pr.dsTypes) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if !pr.dsTypes[t] {
return ErrQueryParamMismatch
}
}
}
return nil
}
func splitHeaders(headers []string) []string {
out := []string{}
for _, v := range headers {
if strings.Contains(v, ",") {
for _, sub := range strings.Split(v, ",") {
out = append(out, strings.TrimSpace(sub))
}
} else {
out = append(out, v)
}
}
return out
}
// parseRequest parses a request into parsed queries grouped by datasource uid // parseRequest parses a request into parsed queries grouped by datasource uid
func (s *Service) parseMetricRequest(ctx context.Context, user *user.SignedInUser, skipCache bool, reqDTO dtos.MetricRequest) (*parsedRequest, error) { func (s *Service) parseMetricRequest(ctx context.Context, user *user.SignedInUser, skipCache bool, reqDTO dtos.MetricRequest) (*parsedRequest, error) {
if len(reqDTO.Queries) == 0 { if len(reqDTO.Queries) == 0 {

View File

@@ -0,0 +1,157 @@
package query
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/proxyutil"
"golang.org/x/oauth2"
)
type parsedQuery struct {
datasource *datasources.DataSource
query backend.DataQuery
rawQuery *simplejson.Json
}
type parsedRequest struct {
hasExpression bool
parsedQueries map[string][]parsedQuery
dsTypes map[string]bool
httpRequest *http.Request
}
func (pr parsedRequest) getFlattenedQueries() []parsedQuery {
queries := make([]parsedQuery, 0)
for _, pq := range pr.parsedQueries {
queries = append(queries, pq...)
}
return queries
}
func (pr parsedRequest) validateRequest() error {
if pr.httpRequest == nil {
return nil
}
if pr.hasExpression {
hasExpr := pr.httpRequest.URL.Query().Get("expression")
if hasExpr == "" || hasExpr == "true" {
return nil
}
return ErrQueryParamMismatch
}
vals := splitHeaders(pr.httpRequest.Header.Values(HeaderDatasourceUID))
count := len(vals)
if count > 0 { // header exists
if count != len(pr.parsedQueries) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if pr.parsedQueries[t] == nil {
return ErrQueryParamMismatch
}
}
}
vals = splitHeaders(pr.httpRequest.Header.Values(HeaderPluginID))
count = len(vals)
if count > 0 { // header exists
if count != len(pr.dsTypes) {
return ErrQueryParamMismatch
}
for _, t := range vals {
if !pr.dsTypes[t] {
return ErrQueryParamMismatch
}
}
}
return nil
}
func (pr parsedRequest) createDataSourceQueryEnrichers(ctx context.Context, signedInUser *user.SignedInUser, oAuthTokenService oauthtoken.OAuthTokenService, disallowedCookies []string) map[string]expr.QueryDataRequestEnricher {
datasourcesHeaderProvider := map[string]expr.QueryDataRequestEnricher{}
if pr.httpRequest == nil {
return datasourcesHeaderProvider
}
for uid, queries := range pr.parsedQueries {
if expr.IsDataSource(uid) {
continue
}
if len(queries) == 0 || queries[0].datasource == nil {
continue
}
if _, exists := datasourcesHeaderProvider[uid]; exists {
continue
}
ds := queries[0].datasource
allowedCookies := ds.AllowedCookies()
clonedReq := pr.httpRequest.Clone(pr.httpRequest.Context())
var token *oauth2.Token
if oAuthTokenService.IsOAuthPassThruEnabled(ds) {
token = oAuthTokenService.GetCurrentOAuthToken(ctx, signedInUser)
}
datasourcesHeaderProvider[uid] = func(ctx context.Context, req *backend.QueryDataRequest) context.Context {
if len(req.Headers) == 0 {
req.Headers = map[string]string{}
}
if len(allowedCookies) > 0 {
proxyutil.ClearCookieHeader(clonedReq, allowedCookies, disallowedCookies)
if cookieStr := clonedReq.Header.Get("Cookie"); cookieStr != "" {
req.Headers["Cookie"] = cookieStr
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedCookiesMiddleware(clonedReq.Cookies(), allowedCookies, disallowedCookies))
}
if token != nil {
req.Headers["Authorization"] = fmt.Sprintf("%s %s", token.Type(), token.AccessToken)
idToken, ok := token.Extra("id_token").(string)
if ok && idToken != "" {
req.Headers["X-ID-Token"] = idToken
}
ctx = httpclient.WithContextualMiddleware(ctx, httpclientprovider.ForwardedOAuthIdentityMiddleware(token))
}
return ctx
}
}
return datasourcesHeaderProvider
}
func splitHeaders(headers []string) []string {
out := []string{}
for _, v := range headers {
if strings.Contains(v, ",") {
for _, sub := range strings.Split(v, ",") {
out = append(out, strings.TrimSpace(sub))
}
} else {
out = append(out, v)
}
}
return out
}

View File

@@ -5,9 +5,11 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -16,6 +18,7 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/expr" "github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins"
@@ -33,9 +36,32 @@ import (
) )
func TestParseMetricRequest(t *testing.T) { func TestParseMetricRequest(t *testing.T) {
tc := setup(t)
t.Run("Test a simple single datasource query", func(t *testing.T) { t.Run("Test a simple single datasource query", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{ mr := metricRequestWithQueries(t, `{
"refId": "A", "refId": "A",
"datasource": { "datasource": {
@@ -56,9 +82,61 @@ func TestParseMetricRequest(t *testing.T) {
assert.Len(t, parsedReq.parsedQueries, 1) assert.Len(t, parsedReq.parsedQueries, 1)
assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz")
assert.Len(t, parsedReq.getFlattenedQueries(), 2) assert.Len(t, parsedReq.getFlattenedQueries(), 2)
t.Run("createDataSourceQueryEnrichers should return 0 enrichers when no HTTP request", func(t *testing.T) {
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{})
require.Empty(t, enrichers)
})
t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 1)
require.NotNil(t, enrichers["gIEkMvIVz"])
req := &backend.QueryDataRequest{}
ctx := enrichers["gIEkMvIVz"](context.Background(), req)
require.Len(t, req.Headers, 3)
require.Equal(t, "Bearer access-token", req.Headers["Authorization"])
require.Equal(t, "id-token", req.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"])
middlewares := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewares, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName())
})
}) })
t.Run("Test a single datasource query with expressions", func(t *testing.T) { t.Run("Test a single datasource query with expressions", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{ mr := metricRequestWithQueries(t, `{
"refId": "A", "refId": "A",
"datasource": { "datasource": {
@@ -85,9 +163,68 @@ func TestParseMetricRequest(t *testing.T) {
// Make sure we end up with something valid // Make sure we end up with something valid
_, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq)
assert.NoError(t, err) assert.NoError(t, err)
t.Run("createDataSourceQueryEnrichers should return 1 enricher", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 1)
require.NotNil(t, enrichers["gIEkMvIVz"])
req := &backend.QueryDataRequest{}
ctx := enrichers["gIEkMvIVz"](context.Background(), req)
require.Len(t, req.Headers, 3)
require.Equal(t, "Bearer access-token", req.Headers["Authorization"])
require.Equal(t, "id-token", req.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", req.Headers["Cookie"])
middlewares := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewares, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewares[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewares[1].(httpclient.MiddlewareName).MiddlewareName())
})
}) })
t.Run("Test a simple mixed datasource query", func(t *testing.T) { t.Run("Test a simple mixed datasource query", func(t *testing.T) {
tc := setup(t)
json, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie1", "cookie3", "login" ]
}`))
require.NoError(t, err)
json2, err := simplejson.NewJson([]byte(`{
"keepCookies": [ "cookie2" ]
}`))
require.NoError(t, err)
tc.dataSourceCache.dsByUid = func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if datasourceUID == "gIEkMvIVz" {
return &datasources.DataSource{
Uid: "gIEkMvIVz",
JsonData: json,
}, nil
}
if datasourceUID == "sEx6ZvSVk" {
return &datasources.DataSource{
Uid: "sEx6ZvSVk",
JsonData: json2,
}, nil
}
return nil, nil
}
token := &oauth2.Token{
TokenType: "bearer",
AccessToken: "access-token",
}
token = token.WithExtra(map[string]interface{}{"id_token": "id-token"})
tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token
mr := metricRequestWithQueries(t, `{ mr := metricRequestWithQueries(t, `{
"refId": "A", "refId": "A",
"datasource": { "datasource": {
@@ -100,6 +237,12 @@ func TestParseMetricRequest(t *testing.T) {
"uid": "sEx6ZvSVk", "uid": "sEx6ZvSVk",
"type": "testdata" "type": "testdata"
} }
}`, `{
"refId": "C",
"datasource": {
"uid": "sEx6ZvSVk",
"type": "testdata"
}
}`) }`)
parsedReq, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr) parsedReq, err := tc.queryService.parseMetricRequest(context.Background(), tc.signedInUser, true, mr)
require.NoError(t, err) require.NoError(t, err)
@@ -107,11 +250,51 @@ func TestParseMetricRequest(t *testing.T) {
assert.False(t, parsedReq.hasExpression) assert.False(t, parsedReq.hasExpression)
assert.Len(t, parsedReq.parsedQueries, 2) assert.Len(t, parsedReq.parsedQueries, 2)
assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz") assert.Contains(t, parsedReq.parsedQueries, "gIEkMvIVz")
assert.Len(t, parsedReq.parsedQueries["gIEkMvIVz"], 1)
assert.Contains(t, parsedReq.parsedQueries, "sEx6ZvSVk") assert.Contains(t, parsedReq.parsedQueries, "sEx6ZvSVk")
assert.Len(t, parsedReq.getFlattenedQueries(), 2) assert.Len(t, parsedReq.parsedQueries["sEx6ZvSVk"], 2)
assert.Len(t, parsedReq.getFlattenedQueries(), 3)
t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) {
parsedReq.httpRequest = httptest.NewRequest(http.MethodGet, "/", nil)
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie1"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie2"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "cookie3"})
parsedReq.httpRequest.AddCookie(&http.Cookie{Name: "login"})
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{"login"})
require.Len(t, enrichers, 2)
enricherOne := enrichers["gIEkMvIVz"]
require.NotNil(t, enricherOne)
reqOne := &backend.QueryDataRequest{}
ctx := enricherOne(context.Background(), reqOne)
require.Len(t, reqOne.Headers, 3)
require.Equal(t, "Bearer access-token", reqOne.Headers["Authorization"])
require.Equal(t, "id-token", reqOne.Headers["X-ID-Token"])
require.Equal(t, "cookie1=; cookie3=", reqOne.Headers["Cookie"])
middlewaresOne := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewaresOne, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresOne[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresOne[1].(httpclient.MiddlewareName).MiddlewareName())
enricherTwo := enrichers["sEx6ZvSVk"]
require.NotNil(t, enricherTwo)
reqTwo := &backend.QueryDataRequest{}
ctx = enricherTwo(context.Background(), reqTwo)
require.Len(t, reqTwo.Headers, 3)
require.Equal(t, "Bearer access-token", reqTwo.Headers["Authorization"])
require.Equal(t, "id-token", reqTwo.Headers["X-ID-Token"])
require.Equal(t, "cookie2=", reqTwo.Headers["Cookie"])
middlewaresTwo := httpclient.ContextualMiddlewareFromContext(ctx)
require.Len(t, middlewaresTwo, 2)
require.Equal(t, httpclientprovider.ForwardedCookiesMiddlewareName, middlewaresTwo[0].(httpclient.MiddlewareName).MiddlewareName())
require.Equal(t, httpclientprovider.ForwardedOAuthIdentityMiddlewareName, middlewaresTwo[1].(httpclient.MiddlewareName).MiddlewareName())
})
}) })
t.Run("Test a mixed datasource query with expressions", func(t *testing.T) { t.Run("Test a mixed datasource query with expressions", func(t *testing.T) {
tc := setup(t)
mr := metricRequestWithQueries(t, `{ mr := metricRequestWithQueries(t, `{
"refId": "A", "refId": "A",
"datasource": { "datasource": {
@@ -169,9 +352,18 @@ func TestParseMetricRequest(t *testing.T) {
// Make sure we end up with something valid // Make sure we end up with something valid
_, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq) _, err = tc.queryService.handleExpressions(context.Background(), tc.signedInUser, parsedReq)
assert.NoError(t, err) assert.NoError(t, err)
t.Run("createDataSourceQueryEnrichers should return 2 enrichers", func(t *testing.T) {
parsedReq.httpRequest = &http.Request{}
enrichers := parsedReq.createDataSourceQueryEnrichers(context.Background(), nil, tc.oauthTokenService, []string{})
require.Len(t, enrichers, 2)
require.NotNil(t, enrichers["gIEkMvIVz"])
require.NotNil(t, enrichers["sEx6ZvSVk"])
})
}) })
t.Run("Header validation", func(t *testing.T) { t.Run("Header validation", func(t *testing.T) {
tc := setup(t)
mr := metricRequestWithQueries(t, `{ mr := metricRequestWithQueries(t, `{
"refId": "A", "refId": "A",
"datasource": { "datasource": {
@@ -351,7 +543,12 @@ func TestQueryData(t *testing.T) {
tc.oauthTokenService.passThruEnabled = true tc.oauthTokenService.passThruEnabled = true
tc.oauthTokenService.token = token tc.oauthTokenService.token = token
_, err := tc.queryService.QueryData(context.Background(), nil, true, metricRequest()) metricReq := metricRequest()
httpReq, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
metricReq.HTTPRequest = httpReq
_, err = tc.queryService.QueryData(context.Background(), nil, true, metricReq)
require.Nil(t, err) require.Nil(t, err)
expected := map[string]string{ expected := map[string]string{
@@ -523,7 +720,8 @@ func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.
} }
type fakeDataSourceCache struct { type fakeDataSourceCache struct {
ds *datasources.DataSource ds *datasources.DataSource
dsByUid func(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error)
} }
func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID int64, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
@@ -531,6 +729,10 @@ func (c *fakeDataSourceCache) GetDatasource(ctx context.Context, datasourceID in
} }
func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) { func (c *fakeDataSourceCache) GetDatasourceByUID(ctx context.Context, datasourceUID string, user *user.SignedInUser, skipCache bool) (*datasources.DataSource, error) {
if c.dsByUid != nil {
return c.dsByUid(ctx, datasourceUID, user, skipCache)
}
return &datasources.DataSource{ return &datasources.DataSource{
Uid: datasourceUID, Uid: datasourceUID,
}, nil }, nil