diff --git a/pkg/tsdb/loki/api.go b/pkg/tsdb/loki/api.go index d1d1eb6f813..63c99bd99b6 100644 --- a/pkg/tsdb/loki/api.go +++ b/pkg/tsdb/loki/api.go @@ -16,16 +16,23 @@ import ( ) type LokiAPI struct { - client *http.Client - url string - log log.Logger + client *http.Client + url string + log log.Logger + oauthToken string } -func newLokiAPI(client *http.Client, url string, log log.Logger) *LokiAPI { - return &LokiAPI{client: client, url: url, log: log} +func newLokiAPI(client *http.Client, url string, log log.Logger, oauthToken string) *LokiAPI { + return &LokiAPI{client: client, url: url, log: log, oauthToken: oauthToken} } -func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*http.Request, error) { +func addOauthHeader(req *http.Request, oauthToken string) { + if oauthToken != "" { + req.Header.Set("Authorization", oauthToken) + } +} + +func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, oauthToken string) (*http.Request, error) { qs := url.Values{} qs.Set("query", query.Expr) @@ -78,12 +85,7 @@ func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*h return nil, err } - // NOTE: - // 1. we are missing "dynamic" http params, like OAuth data. - // this never worked before (and it is not needed for alerting scenarios), - // so it is not a regression. - // twe need to have that when we migrate to backend-queries. - // + addOauthHeader(req, oauthToken) if query.VolumeQuery { req.Header.Set("X-Query-Tags", "Source=logvolhist") @@ -136,7 +138,7 @@ func makeLokiError(body io.ReadCloser) error { } func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (*loghttp.QueryResponse, error) { - req, err := makeDataRequest(ctx, api.url, query) + req, err := makeDataRequest(ctx, api.url, query, api.oauthToken) if err != nil { return nil, err } @@ -165,7 +167,7 @@ func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (*loghttp.Qu return &response, nil } -func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string) (*http.Request, error) { +func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string, oauthToken string) (*http.Request, error) { lokiUrl, err := url.Parse(lokiDsUrl) if err != nil { return nil, err @@ -176,11 +178,19 @@ func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string) ( return nil, err } - return http.NewRequestWithContext(ctx, "GET", url.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil) + + if err != nil { + return nil, err + } + + addOauthHeader(req, oauthToken) + + return req, nil } func (api *LokiAPI) RawQuery(ctx context.Context, resourceURL string) ([]byte, error) { - req, err := makeRawRequest(ctx, api.url, resourceURL) + req, err := makeRawRequest(ctx, api.url, resourceURL, api.oauthToken) if err != nil { return nil, err } diff --git a/pkg/tsdb/loki/api_mock.go b/pkg/tsdb/loki/api_mock.go new file mode 100644 index 00000000000..a48633677e7 --- /dev/null +++ b/pkg/tsdb/loki/api_mock.go @@ -0,0 +1,41 @@ +package loki + +import ( + "bytes" + "io" + "net/http" + + "github.com/grafana/grafana/pkg/infra/log" +) + +type mockRequestCallback func(req *http.Request) + +type mockedRoundTripper struct { + statusCode int + responseBytes []byte + contentType string + requestCallback mockRequestCallback +} + +func (mockedRT *mockedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + requestCallback := mockedRT.requestCallback + if requestCallback != nil { + requestCallback(req) + } + + header := http.Header{} + header.Add("Content-Type", mockedRT.contentType) + return &http.Response{ + StatusCode: mockedRT.statusCode, + Header: header, + Body: io.NopCloser(bytes.NewReader(mockedRT.responseBytes)), + }, nil +} + +func makeMockedAPI(statusCode int, contentType string, responseBytes []byte, requestCallback mockRequestCallback) *LokiAPI { + client := http.Client{ + Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback}, + } + + return newLokiAPI(&client, "http://localhost:9999", log.New("test"), "") +} diff --git a/pkg/tsdb/loki/framing_test.go b/pkg/tsdb/loki/framing_test.go index 8525a7f0673..187595484bf 100644 --- a/pkg/tsdb/loki/framing_test.go +++ b/pkg/tsdb/loki/framing_test.go @@ -1,9 +1,7 @@ package loki import ( - "bytes" "context" - "io/ioutil" "net/http" "os" "path/filepath" @@ -12,7 +10,6 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/experimental" - "github.com/grafana/grafana/pkg/infra/log" "github.com/stretchr/testify/require" ) @@ -59,7 +56,7 @@ func TestSuccessResponse(t *testing.T) { bytes, err := os.ReadFile(responseFileName) require.NoError(t, err) - frames, err := runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes), &test.query) + frames, err := runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes, nil), &test.query) require.NoError(t, err) dr := &backend.DataResponse{ @@ -119,7 +116,7 @@ func TestErrorResponse(t *testing.T) { for _, test := range tt { t.Run(test.name, func(t *testing.T) { - frames, err := runQuery(context.Background(), makeMockedAPI(400, test.contentType, test.body), &lokiQuery{QueryType: QueryTypeRange, Direction: DirectionBackward}) + frames, err := runQuery(context.Background(), makeMockedAPI(400, test.contentType, test.body, nil), &lokiQuery{QueryType: QueryTypeRange, Direction: DirectionBackward}) require.Len(t, frames, 0) require.Error(t, err) @@ -127,27 +124,3 @@ func TestErrorResponse(t *testing.T) { }) } } - -type mockedRoundTripper struct { - statusCode int - responseBytes []byte - contentType string -} - -func (mockedRT *mockedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - header := http.Header{} - header.Add("Content-Type", mockedRT.contentType) - return &http.Response{ - StatusCode: mockedRT.statusCode, - Header: header, - Body: ioutil.NopCloser(bytes.NewReader(mockedRT.responseBytes)), - }, nil -} - -func makeMockedAPI(statusCode int, contentType string, responseBytes []byte) *LokiAPI { - client := http.Client{ - Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes}, - } - - return newLokiAPI(&client, "http://localhost:9999", log.New("test")) -} diff --git a/pkg/tsdb/loki/loki.go b/pkg/tsdb/loki/loki.go index 4404d86d961..6a274296f69 100644 --- a/pkg/tsdb/loki/loki.go +++ b/pkg/tsdb/loki/loki.go @@ -44,8 +44,9 @@ var ( ) type datasourceInfo struct { - HTTPClient *http.Client - URL string + HTTPClient *http.Client + URL string + OauthPassThru bool // open streams streams map[string]data.FrameJSONCache @@ -64,6 +65,10 @@ type QueryJSONModel struct { VolumeQuery bool `json:"volumeQuery"` } +type DataSourceJSONModel struct { + OauthPassThru bool `json:"oauthPassThru"` +} + func parseQueryModel(raw json.RawMessage) (*QueryJSONModel, error) { model := &QueryJSONModel{} err := json.Unmarshal(raw, model) @@ -82,16 +87,54 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst return nil, err } + jsonModel := DataSourceJSONModel{} + err = json.Unmarshal(settings.JSONData, &jsonModel) + if err != nil { + return nil, err + } + model := &datasourceInfo{ - HTTPClient: client, - URL: settings.URL, - streams: make(map[string]data.FrameJSONCache), + HTTPClient: client, + URL: settings.URL, + OauthPassThru: jsonModel.OauthPassThru, + streams: make(map[string]data.FrameJSONCache), } return model, nil } } +func getOauthTokenForQueryData(dsInfo *datasourceInfo, headers map[string]string) string { + if !dsInfo.OauthPassThru { + return "" + } + + return headers["Authorization"] +} + +func getOauthTokenForCallResource(dsInfo *datasourceInfo, headers map[string][]string) string { + if !dsInfo.OauthPassThru { + return "" + } + + accessValues := headers["Authorization"] + + if len(accessValues) == 0 { + return "" + } + + return accessValues[0] +} + func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { + dsInfo, err := s.getDSInfo(req.PluginContext) + if err != nil { + return err + } + + return callResource(ctx, req, sender, dsInfo, s.plog) +} + +func callResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, dsInfo *datasourceInfo, plog log.Logger) error { url := req.URL // a very basic is-this-url-valid check @@ -105,12 +148,7 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq } lokiURL := fmt.Sprintf("/loki/api/v1/%s", url) - dsInfo, err := s.getDSInfo(req.PluginContext) - if err != nil { - return err - } - - api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, s.plog) + api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getOauthTokenForCallResource(dsInfo, req.Headers)) bytes, err := api.RawQuery(ctx, lokiURL) if err != nil { @@ -127,14 +165,19 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq } func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - result := backend.NewQueryDataResponse() - dsInfo, err := s.getDSInfo(req.PluginContext) if err != nil { + result := backend.NewQueryDataResponse() return result, err } - api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, s.plog) + return queryData(ctx, req, dsInfo, s.plog, s.tracer) +} + +func queryData(ctx context.Context, req *backend.QueryDataRequest, dsInfo *datasourceInfo, plog log.Logger, tracer tracing.Tracer) (*backend.QueryDataResponse, error) { + result := backend.NewQueryDataResponse() + + api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getOauthTokenForQueryData(dsInfo, req.Headers)) queries, err := parseQuery(req) if err != nil { @@ -142,8 +185,8 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) } for _, query := range queries { - s.plog.Debug("Sending query", "start", query.Start, "end", query.End, "step", query.Step, "query", query.Expr) - _, span := s.tracer.Start(ctx, "alerting.loki") + plog.Debug("Sending query", "start", query.Start, "end", query.End, "step", query.Step, "query", query.Expr) + _, span := tracer.Start(ctx, "alerting.loki") span.SetAttributes("expr", query.Expr, attribute.Key("expr").String(query.Expr)) span.SetAttributes("start_unixnano", query.Start, attribute.Key("start_unixnano").Int64(query.Start.UnixNano())) span.SetAttributes("stop_unixnano", query.End, attribute.Key("stop_unixnano").Int64(query.End.UnixNano())) diff --git a/pkg/tsdb/loki/loki_bench_test.go b/pkg/tsdb/loki/loki_bench_test.go index de0b6d50a17..05e4b6f9273 100644 --- a/pkg/tsdb/loki/loki_bench_test.go +++ b/pkg/tsdb/loki/loki_bench_test.go @@ -17,7 +17,7 @@ func BenchmarkMatrixJson(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - _, _ = runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes), &lokiQuery{}) + _, _ = runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes, nil), &lokiQuery{}) } } diff --git a/pkg/tsdb/loki/oauth_test.go b/pkg/tsdb/loki/oauth_test.go new file mode 100644 index 00000000000..641c29cc59c --- /dev/null +++ b/pkg/tsdb/loki/oauth_test.go @@ -0,0 +1,164 @@ +package loki + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/stretchr/testify/require" +) + +type mockedRoundTripperForOauth struct { + requestCallback func(req *http.Request) + body []byte +} + +func (mockedRT *mockedRoundTripperForOauth) RoundTrip(req *http.Request) (*http.Response, error) { + mockedRT.requestCallback(req) + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(mockedRT.body)), + }, nil +} + +type mockedCallResourceResponseSenderForOauth struct { + Response *backend.CallResourceResponse +} + +func (s *mockedCallResourceResponseSenderForOauth) Send(resp *backend.CallResourceResponse) error { + s.Response = resp + return nil +} + +func makeMockedDsInfoForOauth(oauthPassThru bool, body []byte, requestCallback func(req *http.Request)) datasourceInfo { + client := http.Client{ + Transport: &mockedRoundTripperForOauth{requestCallback: requestCallback, body: body}, + } + + return datasourceInfo{ + HTTPClient: &client, + OauthPassThru: oauthPassThru, + } +} + +func TestOauthForwardIdentity(t *testing.T) { + tt := []struct { + name string + oauthPassThru bool + headerGiven bool + headerSent bool + }{ + {name: "when enabled and headers exist => add headers", oauthPassThru: true, headerGiven: true, headerSent: true}, + {name: "when disabled and headers exist => do not add headers", oauthPassThru: false, headerGiven: true, headerSent: false}, + {name: "when enabled and no headers exist => do not add headers", oauthPassThru: true, headerGiven: false, headerSent: false}, + {name: "when disabled and no headers exist => do not add headers", oauthPassThru: false, headerGiven: false, headerSent: false}, + } + + authName := "Authorization" + authValue := "auth" + + for _, test := range tt { + t.Run("QueryData: "+test.name, func(t *testing.T) { + response := []byte(` + { + "status": "success", + "data": { + "resultType": "streams", + "result": [ + { + "stream": {}, + "values": [ + ["1", "line1"] + ] + } + ] + } + } + `) + + clientUsed := false + dsInfo := makeMockedDsInfoForOauth(test.oauthPassThru, response, func(req *http.Request) { + clientUsed = true + if test.headerSent { + require.Equal(t, authValue, req.Header.Get(authName)) + } else { + require.Equal(t, "", req.Header.Get(authName)) + } + }) + + req := backend.QueryDataRequest{ + Headers: map[string]string{}, + Queries: []backend.DataQuery{ + { + RefID: "A", + JSON: []byte("{}"), + }, + }, + } + + if test.headerGiven { + req.Headers[authName] = authValue + } + + tracer, err := tracing.InitializeTracerForTest() + require.NoError(t, err) + + data, err := queryData(context.Background(), &req, &dsInfo, log.New("testlog"), tracer) + // we do a basic check that the result is OK + require.NoError(t, err) + require.Len(t, data.Responses, 1) + res := data.Responses["A"] + require.NoError(t, res.Error) + require.Len(t, res.Frames, 1) + require.Equal(t, "line1", res.Frames[0].Fields[2].At(0)) + + // we need to be sure the client-callback was triggered + require.True(t, clientUsed) + }) + } + + for _, test := range tt { + t.Run("CallResource: "+test.name, func(t *testing.T) { + response := []byte("mocked resource response") + + clientUsed := false + dsInfo := makeMockedDsInfoForOauth(test.oauthPassThru, response, func(req *http.Request) { + clientUsed = true + if test.headerSent { + require.Equal(t, authValue, req.Header.Get(authName)) + } else { + require.Equal(t, "", req.Header.Get(authName)) + } + }) + + req := backend.CallResourceRequest{ + Headers: map[string][]string{}, + Method: "GET", + URL: "labels?", + } + + if test.headerGiven { + req.Headers[authName] = []string{authValue} + } + + sender := &mockedCallResourceResponseSenderForOauth{} + + err := callResource(context.Background(), &req, sender, &dsInfo, log.New("testlog")) + // we do a basic check that the result is OK + require.NoError(t, err) + sent := sender.Response + require.NotNil(t, sent) + require.Equal(t, http.StatusOK, sent.Status) + require.Equal(t, response, sent.Body) + + // we need to be sure the client-callback was triggered + require.True(t, clientUsed) + }) + } +}