loki backend mode forward-oauth (#48401)

* loki: backend: add forward oauth credentials functionality

* removed obsolete comment
This commit is contained in:
Gábor Farkas 2022-05-05 12:42:50 +02:00 committed by GitHub
parent 9b8cdab123
commit 02aa1cd1c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 293 additions and 62 deletions

View File

@ -16,16 +16,23 @@ import (
)
type LokiAPI struct {
client *http.Client
url string
log log.Logger
client *http.Client
url string
log log.Logger
oauthToken string
}
func newLokiAPI(client *http.Client, url string, log log.Logger) *LokiAPI {
return &LokiAPI{client: client, url: url, log: log}
func newLokiAPI(client *http.Client, url string, log log.Logger, oauthToken string) *LokiAPI {
return &LokiAPI{client: client, url: url, log: log, oauthToken: oauthToken}
}
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*http.Request, error) {
func addOauthHeader(req *http.Request, oauthToken string) {
if oauthToken != "" {
req.Header.Set("Authorization", oauthToken)
}
}
func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery, oauthToken string) (*http.Request, error) {
qs := url.Values{}
qs.Set("query", query.Expr)
@ -78,12 +85,7 @@ func makeDataRequest(ctx context.Context, lokiDsUrl string, query lokiQuery) (*h
return nil, err
}
// NOTE:
// 1. we are missing "dynamic" http params, like OAuth data.
// this never worked before (and it is not needed for alerting scenarios),
// so it is not a regression.
// twe need to have that when we migrate to backend-queries.
//
addOauthHeader(req, oauthToken)
if query.VolumeQuery {
req.Header.Set("X-Query-Tags", "Source=logvolhist")
@ -136,7 +138,7 @@ func makeLokiError(body io.ReadCloser) error {
}
func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (*loghttp.QueryResponse, error) {
req, err := makeDataRequest(ctx, api.url, query)
req, err := makeDataRequest(ctx, api.url, query, api.oauthToken)
if err != nil {
return nil, err
}
@ -165,7 +167,7 @@ func (api *LokiAPI) DataQuery(ctx context.Context, query lokiQuery) (*loghttp.Qu
return &response, nil
}
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string) (*http.Request, error) {
func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string, oauthToken string) (*http.Request, error) {
lokiUrl, err := url.Parse(lokiDsUrl)
if err != nil {
return nil, err
@ -176,11 +178,19 @@ func makeRawRequest(ctx context.Context, lokiDsUrl string, resourceURL string) (
return nil, err
}
return http.NewRequestWithContext(ctx, "GET", url.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil)
if err != nil {
return nil, err
}
addOauthHeader(req, oauthToken)
return req, nil
}
func (api *LokiAPI) RawQuery(ctx context.Context, resourceURL string) ([]byte, error) {
req, err := makeRawRequest(ctx, api.url, resourceURL)
req, err := makeRawRequest(ctx, api.url, resourceURL, api.oauthToken)
if err != nil {
return nil, err
}

41
pkg/tsdb/loki/api_mock.go Normal file
View File

@ -0,0 +1,41 @@
package loki
import (
"bytes"
"io"
"net/http"
"github.com/grafana/grafana/pkg/infra/log"
)
type mockRequestCallback func(req *http.Request)
type mockedRoundTripper struct {
statusCode int
responseBytes []byte
contentType string
requestCallback mockRequestCallback
}
func (mockedRT *mockedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
requestCallback := mockedRT.requestCallback
if requestCallback != nil {
requestCallback(req)
}
header := http.Header{}
header.Add("Content-Type", mockedRT.contentType)
return &http.Response{
StatusCode: mockedRT.statusCode,
Header: header,
Body: io.NopCloser(bytes.NewReader(mockedRT.responseBytes)),
}, nil
}
func makeMockedAPI(statusCode int, contentType string, responseBytes []byte, requestCallback mockRequestCallback) *LokiAPI {
client := http.Client{
Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes, requestCallback: requestCallback},
}
return newLokiAPI(&client, "http://localhost:9999", log.New("test"), "")
}

View File

@ -1,9 +1,7 @@
package loki
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"os"
"path/filepath"
@ -12,7 +10,6 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/experimental"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/stretchr/testify/require"
)
@ -59,7 +56,7 @@ func TestSuccessResponse(t *testing.T) {
bytes, err := os.ReadFile(responseFileName)
require.NoError(t, err)
frames, err := runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes), &test.query)
frames, err := runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes, nil), &test.query)
require.NoError(t, err)
dr := &backend.DataResponse{
@ -119,7 +116,7 @@ func TestErrorResponse(t *testing.T) {
for _, test := range tt {
t.Run(test.name, func(t *testing.T) {
frames, err := runQuery(context.Background(), makeMockedAPI(400, test.contentType, test.body), &lokiQuery{QueryType: QueryTypeRange, Direction: DirectionBackward})
frames, err := runQuery(context.Background(), makeMockedAPI(400, test.contentType, test.body, nil), &lokiQuery{QueryType: QueryTypeRange, Direction: DirectionBackward})
require.Len(t, frames, 0)
require.Error(t, err)
@ -127,27 +124,3 @@ func TestErrorResponse(t *testing.T) {
})
}
}
type mockedRoundTripper struct {
statusCode int
responseBytes []byte
contentType string
}
func (mockedRT *mockedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
header := http.Header{}
header.Add("Content-Type", mockedRT.contentType)
return &http.Response{
StatusCode: mockedRT.statusCode,
Header: header,
Body: ioutil.NopCloser(bytes.NewReader(mockedRT.responseBytes)),
}, nil
}
func makeMockedAPI(statusCode int, contentType string, responseBytes []byte) *LokiAPI {
client := http.Client{
Transport: &mockedRoundTripper{statusCode: statusCode, contentType: contentType, responseBytes: responseBytes},
}
return newLokiAPI(&client, "http://localhost:9999", log.New("test"))
}

View File

@ -44,8 +44,9 @@ var (
)
type datasourceInfo struct {
HTTPClient *http.Client
URL string
HTTPClient *http.Client
URL string
OauthPassThru bool
// open streams
streams map[string]data.FrameJSONCache
@ -64,6 +65,10 @@ type QueryJSONModel struct {
VolumeQuery bool `json:"volumeQuery"`
}
type DataSourceJSONModel struct {
OauthPassThru bool `json:"oauthPassThru"`
}
func parseQueryModel(raw json.RawMessage) (*QueryJSONModel, error) {
model := &QueryJSONModel{}
err := json.Unmarshal(raw, model)
@ -82,16 +87,54 @@ func newInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
return nil, err
}
jsonModel := DataSourceJSONModel{}
err = json.Unmarshal(settings.JSONData, &jsonModel)
if err != nil {
return nil, err
}
model := &datasourceInfo{
HTTPClient: client,
URL: settings.URL,
streams: make(map[string]data.FrameJSONCache),
HTTPClient: client,
URL: settings.URL,
OauthPassThru: jsonModel.OauthPassThru,
streams: make(map[string]data.FrameJSONCache),
}
return model, nil
}
}
func getOauthTokenForQueryData(dsInfo *datasourceInfo, headers map[string]string) string {
if !dsInfo.OauthPassThru {
return ""
}
return headers["Authorization"]
}
func getOauthTokenForCallResource(dsInfo *datasourceInfo, headers map[string][]string) string {
if !dsInfo.OauthPassThru {
return ""
}
accessValues := headers["Authorization"]
if len(accessValues) == 0 {
return ""
}
return accessValues[0]
}
func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error {
dsInfo, err := s.getDSInfo(req.PluginContext)
if err != nil {
return err
}
return callResource(ctx, req, sender, dsInfo, s.plog)
}
func callResource(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender, dsInfo *datasourceInfo, plog log.Logger) error {
url := req.URL
// a very basic is-this-url-valid check
@ -105,12 +148,7 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq
}
lokiURL := fmt.Sprintf("/loki/api/v1/%s", url)
dsInfo, err := s.getDSInfo(req.PluginContext)
if err != nil {
return err
}
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, s.plog)
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getOauthTokenForCallResource(dsInfo, req.Headers))
bytes, err := api.RawQuery(ctx, lokiURL)
if err != nil {
@ -127,14 +165,19 @@ func (s *Service) CallResource(ctx context.Context, req *backend.CallResourceReq
}
func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
result := backend.NewQueryDataResponse()
dsInfo, err := s.getDSInfo(req.PluginContext)
if err != nil {
result := backend.NewQueryDataResponse()
return result, err
}
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, s.plog)
return queryData(ctx, req, dsInfo, s.plog, s.tracer)
}
func queryData(ctx context.Context, req *backend.QueryDataRequest, dsInfo *datasourceInfo, plog log.Logger, tracer tracing.Tracer) (*backend.QueryDataResponse, error) {
result := backend.NewQueryDataResponse()
api := newLokiAPI(dsInfo.HTTPClient, dsInfo.URL, plog, getOauthTokenForQueryData(dsInfo, req.Headers))
queries, err := parseQuery(req)
if err != nil {
@ -142,8 +185,8 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
}
for _, query := range queries {
s.plog.Debug("Sending query", "start", query.Start, "end", query.End, "step", query.Step, "query", query.Expr)
_, span := s.tracer.Start(ctx, "alerting.loki")
plog.Debug("Sending query", "start", query.Start, "end", query.End, "step", query.Step, "query", query.Expr)
_, span := tracer.Start(ctx, "alerting.loki")
span.SetAttributes("expr", query.Expr, attribute.Key("expr").String(query.Expr))
span.SetAttributes("start_unixnano", query.Start, attribute.Key("start_unixnano").Int64(query.Start.UnixNano()))
span.SetAttributes("stop_unixnano", query.End, attribute.Key("stop_unixnano").Int64(query.End.UnixNano()))

View File

@ -17,7 +17,7 @@ func BenchmarkMatrixJson(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_, _ = runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes), &lokiQuery{})
_, _ = runQuery(context.Background(), makeMockedAPI(http.StatusOK, "application/json", bytes, nil), &lokiQuery{})
}
}

164
pkg/tsdb/loki/oauth_test.go Normal file
View File

@ -0,0 +1,164 @@
package loki
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/stretchr/testify/require"
)
type mockedRoundTripperForOauth struct {
requestCallback func(req *http.Request)
body []byte
}
func (mockedRT *mockedRoundTripperForOauth) RoundTrip(req *http.Request) (*http.Response, error) {
mockedRT.requestCallback(req)
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(mockedRT.body)),
}, nil
}
type mockedCallResourceResponseSenderForOauth struct {
Response *backend.CallResourceResponse
}
func (s *mockedCallResourceResponseSenderForOauth) Send(resp *backend.CallResourceResponse) error {
s.Response = resp
return nil
}
func makeMockedDsInfoForOauth(oauthPassThru bool, body []byte, requestCallback func(req *http.Request)) datasourceInfo {
client := http.Client{
Transport: &mockedRoundTripperForOauth{requestCallback: requestCallback, body: body},
}
return datasourceInfo{
HTTPClient: &client,
OauthPassThru: oauthPassThru,
}
}
func TestOauthForwardIdentity(t *testing.T) {
tt := []struct {
name string
oauthPassThru bool
headerGiven bool
headerSent bool
}{
{name: "when enabled and headers exist => add headers", oauthPassThru: true, headerGiven: true, headerSent: true},
{name: "when disabled and headers exist => do not add headers", oauthPassThru: false, headerGiven: true, headerSent: false},
{name: "when enabled and no headers exist => do not add headers", oauthPassThru: true, headerGiven: false, headerSent: false},
{name: "when disabled and no headers exist => do not add headers", oauthPassThru: false, headerGiven: false, headerSent: false},
}
authName := "Authorization"
authValue := "auth"
for _, test := range tt {
t.Run("QueryData: "+test.name, func(t *testing.T) {
response := []byte(`
{
"status": "success",
"data": {
"resultType": "streams",
"result": [
{
"stream": {},
"values": [
["1", "line1"]
]
}
]
}
}
`)
clientUsed := false
dsInfo := makeMockedDsInfoForOauth(test.oauthPassThru, response, func(req *http.Request) {
clientUsed = true
if test.headerSent {
require.Equal(t, authValue, req.Header.Get(authName))
} else {
require.Equal(t, "", req.Header.Get(authName))
}
})
req := backend.QueryDataRequest{
Headers: map[string]string{},
Queries: []backend.DataQuery{
{
RefID: "A",
JSON: []byte("{}"),
},
},
}
if test.headerGiven {
req.Headers[authName] = authValue
}
tracer, err := tracing.InitializeTracerForTest()
require.NoError(t, err)
data, err := queryData(context.Background(), &req, &dsInfo, log.New("testlog"), tracer)
// we do a basic check that the result is OK
require.NoError(t, err)
require.Len(t, data.Responses, 1)
res := data.Responses["A"]
require.NoError(t, res.Error)
require.Len(t, res.Frames, 1)
require.Equal(t, "line1", res.Frames[0].Fields[2].At(0))
// we need to be sure the client-callback was triggered
require.True(t, clientUsed)
})
}
for _, test := range tt {
t.Run("CallResource: "+test.name, func(t *testing.T) {
response := []byte("mocked resource response")
clientUsed := false
dsInfo := makeMockedDsInfoForOauth(test.oauthPassThru, response, func(req *http.Request) {
clientUsed = true
if test.headerSent {
require.Equal(t, authValue, req.Header.Get(authName))
} else {
require.Equal(t, "", req.Header.Get(authName))
}
})
req := backend.CallResourceRequest{
Headers: map[string][]string{},
Method: "GET",
URL: "labels?",
}
if test.headerGiven {
req.Headers[authName] = []string{authValue}
}
sender := &mockedCallResourceResponseSenderForOauth{}
err := callResource(context.Background(), &req, sender, &dsInfo, log.New("testlog"))
// we do a basic check that the result is OK
require.NoError(t, err)
sent := sender.Response
require.NotNil(t, sent)
require.Equal(t, http.StatusOK, sent.Status)
require.Equal(t, response, sent.Body)
// we need to be sure the client-callback was triggered
require.True(t, clientUsed)
})
}
}