mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AzureMonitor: Use auth middleware for QueryData requests (#35343)
This commit is contained in:
committed by
GitHub
parent
36c997a625
commit
7109285ac9
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
@@ -57,7 +58,7 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
|
||||
if tokenProvider, err := getTokenProvider(ctx, cfg, ds, route, data); err != nil {
|
||||
logger.Error("Failed to resolve auth token provider", "error", err)
|
||||
} else if tokenProvider != nil {
|
||||
if token, err := tokenProvider.getAccessToken(); err != nil {
|
||||
if token, err := tokenProvider.GetAccessToken(); err != nil {
|
||||
logger.Error("Failed to get access token", "error", err)
|
||||
} else {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
@@ -90,7 +91,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
|
||||
if tokenAuth == nil {
|
||||
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
|
||||
}
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, pluginRoute, tokenAuth)
|
||||
provider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
return provider, nil
|
||||
|
||||
case "gce":
|
||||
|
@@ -3,7 +3,7 @@ package pluginproxy
|
||||
import "time"
|
||||
|
||||
type accessTokenProvider interface {
|
||||
getAccessToken() (string, error)
|
||||
GetAccessToken() (string, error)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@@ -27,7 +27,7 @@ func newGceAccessTokenProvider(ctx context.Context, ds *models.DataSource, plugi
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *gceAccessTokenProvider) getAccessToken() (string, error) {
|
||||
func (provider *gceAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
tokenSrc, err := google.DefaultTokenSource(provider.ctx, provider.authParams.Scopes...)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default token from meta data server", "error", err)
|
||||
|
@@ -78,7 +78,7 @@ func newGenericAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.A
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *genericAccessTokenProvider) getAccessToken() (string, error) {
|
||||
func (provider *genericAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
tokenCache.Lock()
|
||||
defer tokenCache.Unlock()
|
||||
if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
|
||||
|
@@ -42,7 +42,7 @@ func newJwtAccessTokenProvider(ctx context.Context, ds *models.DataSource, plugi
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *jwtAccessTokenProvider) getAccessToken() (string, error) {
|
||||
func (provider *jwtAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
oauthJwtTokenCache.Lock()
|
||||
defer oauthJwtTokenCache.Unlock()
|
||||
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
|
||||
|
@@ -70,7 +70,7 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
return &oauth2.Token{AccessToken: "abc"}, nil
|
||||
})
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
|
||||
token, err := provider.getAccessToken()
|
||||
token, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc", token)
|
||||
@@ -89,7 +89,7 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
})
|
||||
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
|
||||
_, err := provider.getAccessToken()
|
||||
_, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -100,14 +100,14 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
Expiry: time.Now().Add(1 * time.Minute)}, nil
|
||||
})
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
|
||||
token1, err := provider.getAccessToken()
|
||||
token1, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "abc", token1)
|
||||
|
||||
getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{AccessToken: "error: cache not used"}, nil
|
||||
}
|
||||
token2, err := provider.getAccessToken()
|
||||
token2, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "abc", token2)
|
||||
})
|
||||
@@ -224,12 +224,12 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
|
||||
token["expires_on"] = testCase.expiresOn
|
||||
}
|
||||
|
||||
accessToken, err := provider.getAccessToken()
|
||||
accessToken, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
|
||||
// getAccessToken should use internal cache
|
||||
accessToken, err = provider.getAccessToken()
|
||||
// GetAccessToken should use internal cache
|
||||
accessToken, err = provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
assert.Equal(t, 1, authCalls)
|
||||
@@ -259,13 +259,13 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
|
||||
"token_type": "3600",
|
||||
"refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",
|
||||
}
|
||||
accessToken, err := provider.getAccessToken()
|
||||
accessToken, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
|
||||
mockTimeNow(timeNow().Add(3601 * time.Second))
|
||||
|
||||
accessToken, err = provider.getAccessToken()
|
||||
accessToken, err = provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
assert.Equal(t, 2, authCalls)
|
||||
|
@@ -3,7 +3,6 @@ package azuremonitor
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -15,22 +14,13 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/grafana/grafana/pkg/api/pluginproxy"
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
)
|
||||
|
||||
// ApplicationInsightsDatasource calls the application insights query API.
|
||||
type ApplicationInsightsDatasource struct {
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
type ApplicationInsightsDatasource struct{}
|
||||
|
||||
// ApplicationInsightsQuery is the model that holds the information
|
||||
// needed to make a metrics query to Application Insights, and the information
|
||||
@@ -164,7 +154,7 @@ func (e *ApplicationInsightsDatasource) executeQuery(ctx context.Context, query
|
||||
}
|
||||
|
||||
azlog.Debug("ApplicationInsights", "Request URL", req.URL.String())
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.Services[appInsights].HTTPClient, req)
|
||||
if err != nil {
|
||||
dataResponse.Error = err
|
||||
return dataResponse, nil
|
||||
@@ -204,63 +194,20 @@ func (e *ApplicationInsightsDatasource) executeQuery(ctx context.Context, query
|
||||
}
|
||||
|
||||
func (e *ApplicationInsightsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
|
||||
// find plugin
|
||||
plugin := e.pluginManager.GetDataSource(dsName)
|
||||
if plugin == nil {
|
||||
return nil, errors.New("unable to find datasource plugin Azure Application Insights")
|
||||
}
|
||||
|
||||
appInsightsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appInsightsAppID := dsInfo.Settings.AppInsightsAppId
|
||||
|
||||
u, err := url.Parse(dsInfo.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, fmt.Sprintf("/v1/apps/%s", appInsightsAppID))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[appInsights].URL, nil)
|
||||
if err != nil {
|
||||
azlog.Debug("Failed to create request", "error", err)
|
||||
return nil, errutil.Wrap("Failed to create request", err)
|
||||
}
|
||||
req.Header.Set("X-API-Key", dsInfo.DecryptedSecureJSONData["appInsightsApiKey"])
|
||||
|
||||
// TODO: Use backend authentication instead
|
||||
proxyPass := fmt.Sprintf("%s/v1/apps/%s", routeName, appInsightsAppID)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, &models.DataSource{
|
||||
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
|
||||
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
|
||||
}, e.cfg)
|
||||
req.URL.Path = fmt.Sprintf("/v1/apps/%s", appInsightsAppID)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (e *ApplicationInsightsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
|
||||
cloud, err := getAzureCloud(e.cfg, dsInfo)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
routeName, err := getAppInsightsApiRoute(cloud)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var pluginRoute *plugins.AppPluginRoute
|
||||
for _, route := range plugin.Routes {
|
||||
if route.Path == routeName {
|
||||
pluginRoute = route
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return pluginRoute, routeName, nil
|
||||
}
|
||||
|
||||
// formatApplicationInsightsLegendKey builds the legend key or timeseries name
|
||||
// Alias patterns like {{metric}} are replaced with the appropriate data values.
|
||||
func formatApplicationInsightsLegendKey(alias string, metricName string, labels data.Labels) string {
|
||||
|
@@ -1,15 +1,14 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
@@ -159,92 +158,6 @@ func TestApplicationInsightsDatasource(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppInsightsPluginRoutes(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzurePublic,
|
||||
ManagedIdentityEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
plugin := &plugins.DataSourcePlugin{
|
||||
Routes: []*plugins.AppPluginRoute{
|
||||
{
|
||||
Path: "appinsights",
|
||||
Method: "GET",
|
||||
URL: "https://api.applicationinsights.io",
|
||||
Headers: []plugins.AppPluginRouteHeader{
|
||||
{Name: "X-API-Key", Content: "{{.SecureJsonData.appInsightsApiKey}}"},
|
||||
{Name: "x-ms-app", Content: "Grafana"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: "chinaappinsights",
|
||||
Method: "GET",
|
||||
URL: "https://api.applicationinsights.azure.cn",
|
||||
Headers: []plugins.AppPluginRouteHeader{
|
||||
{Name: "X-API-Key", Content: "{{.SecureJsonData.appInsightsApiKey}}"},
|
||||
{Name: "x-ms-app", Content: "Grafana"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
datasource *ApplicationInsightsDatasource
|
||||
dsInfo datasourceInfo
|
||||
expectedRouteName string
|
||||
expectedRouteURL string
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "plugin proxy route for the Azure public cloud",
|
||||
dsInfo: datasourceInfo{
|
||||
Settings: azureMonitorSettings{
|
||||
AzureAuthType: AzureAuthClientSecret,
|
||||
CloudName: "azuremonitor",
|
||||
},
|
||||
},
|
||||
datasource: &ApplicationInsightsDatasource{
|
||||
cfg: cfg,
|
||||
},
|
||||
expectedRouteName: "appinsights",
|
||||
expectedRouteURL: "https://api.applicationinsights.io",
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "plugin proxy route for the Azure China cloud",
|
||||
dsInfo: datasourceInfo{
|
||||
Settings: azureMonitorSettings{
|
||||
AzureAuthType: AzureAuthClientSecret,
|
||||
CloudName: "chinaazuremonitor",
|
||||
},
|
||||
},
|
||||
datasource: &ApplicationInsightsDatasource{
|
||||
cfg: cfg,
|
||||
},
|
||||
expectedRouteName: "chinaappinsights",
|
||||
expectedRouteURL: "https://api.applicationinsights.azure.cn",
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
route, routeName, err := tt.datasource.getPluginRoute(plugin, tt.dsInfo)
|
||||
tt.Err(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.expectedRouteURL, route.URL, cmpopts.EquateNaNs()); diff != "" {
|
||||
t.Errorf("Result mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedRouteName, routeName, cmpopts.EquateNaNs()); diff != "" {
|
||||
t.Errorf("Result mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestInsightsDimensionsUnmarshalJSON(t *testing.T) {
|
||||
a := []byte(`"foo"`)
|
||||
b := []byte(`["foo"]`)
|
||||
@@ -291,3 +204,46 @@ func TestInsightsDimensionsUnmarshalJSON(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, gs)
|
||||
}
|
||||
|
||||
func TestAppInsightsCreateRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsInfo := datasourceInfo{
|
||||
Settings: azureMonitorSettings{AppInsightsAppId: "foo"},
|
||||
Services: map[string]datasourceService{
|
||||
appInsights: {URL: "http://ds"},
|
||||
},
|
||||
DecryptedSecureJSONData: map[string]string{
|
||||
"appInsightsApiKey": "key",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedURL string
|
||||
expectedHeaders http.Header
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates a request",
|
||||
expectedURL: "http://ds/v1/apps/foo",
|
||||
expectedHeaders: http.Header{
|
||||
"X-Api-Key": []string{"key"},
|
||||
},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ds := ApplicationInsightsDatasource{}
|
||||
req, err := ds.createRequest(ctx, dsInfo)
|
||||
tt.Err(t, err)
|
||||
if req.URL.String() != tt.expectedURL {
|
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
|
||||
}
|
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) {
|
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -5,7 +5,6 @@ import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -16,22 +15,14 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/grafana/grafana/pkg/api/pluginproxy"
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
)
|
||||
|
||||
// AzureLogAnalyticsDatasource calls the Azure Log Analytics API's
|
||||
type AzureLogAnalyticsDatasource struct {
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
type AzureLogAnalyticsDatasource struct{}
|
||||
|
||||
// AzureLogAnalyticsQuery is the query request that is built from the saved values for
|
||||
// from the UI
|
||||
@@ -170,7 +161,7 @@ func (e *AzureLogAnalyticsDatasource) executeQuery(ctx context.Context, query *A
|
||||
}
|
||||
|
||||
azlog.Debug("AzureLogAnalytics", "Request ApiURL", req.URL.String())
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureLogAnalytics].HTTPClient, req)
|
||||
if err != nil {
|
||||
return dataResponseErrorWithExecuted(err)
|
||||
}
|
||||
@@ -220,62 +211,17 @@ func (e *AzureLogAnalyticsDatasource) executeQuery(ctx context.Context, query *A
|
||||
}
|
||||
|
||||
func (e *AzureLogAnalyticsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
|
||||
// find plugin
|
||||
plugin := e.pluginManager.GetDataSource(dsName)
|
||||
if plugin == nil {
|
||||
return nil, errors.New("unable to find datasource plugin Azure Monitor")
|
||||
}
|
||||
|
||||
logAnalyticsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(dsInfo.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "render")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[azureLogAnalytics].URL, nil)
|
||||
if err != nil {
|
||||
azlog.Debug("Failed to create request", "error", err)
|
||||
return nil, errutil.Wrap("failed to create request", err)
|
||||
}
|
||||
|
||||
req.URL.Path = "/"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// TODO: Use backend authentication instead
|
||||
pluginproxy.ApplyRoute(ctx, req, routeName, logAnalyticsRoute, &models.DataSource{
|
||||
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
|
||||
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
|
||||
}, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (e *AzureLogAnalyticsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
|
||||
cloud, err := getAzureCloud(e.cfg, dsInfo)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
routeName, err := getLogAnalyticsApiRoute(cloud)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var pluginRoute *plugins.AppPluginRoute
|
||||
for _, route := range plugin.Routes {
|
||||
if route.Path == routeName {
|
||||
pluginRoute = route
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return pluginRoute, routeName, nil
|
||||
}
|
||||
|
||||
// GetPrimaryResultTable returns the first table in the response named "PrimaryResult", or an
|
||||
// error if there is no table by that name.
|
||||
func (ar *AzureLogAnalyticsResponse) GetPrimaryResultTable() (*AzureResponseTable, error) {
|
||||
|
@@ -1,16 +1,15 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -179,109 +178,38 @@ func TestBuildingAzureLogAnalyticsQueries(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginRoutes(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzurePublic,
|
||||
ManagedIdentityEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
plugin := &plugins.DataSourcePlugin{
|
||||
Routes: []*plugins.AppPluginRoute{
|
||||
{
|
||||
Path: "loganalyticsazure",
|
||||
Method: "GET",
|
||||
URL: "https://api.loganalytics.io/",
|
||||
Headers: []plugins.AppPluginRouteHeader{
|
||||
{Name: "x-ms-app", Content: "Grafana"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: "chinaloganalyticsazure",
|
||||
Method: "GET",
|
||||
URL: "https://api.loganalytics.azure.cn/",
|
||||
Headers: []plugins.AppPluginRouteHeader{
|
||||
{Name: "x-ms-app", Content: "Grafana"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: "govloganalyticsazure",
|
||||
Method: "GET",
|
||||
URL: "https://api.loganalytics.us/",
|
||||
Headers: []plugins.AppPluginRouteHeader{
|
||||
{Name: "x-ms-app", Content: "Grafana"},
|
||||
},
|
||||
},
|
||||
func TestLogAnalyticsCreateRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsInfo := datasourceInfo{
|
||||
Services: map[string]datasourceService{
|
||||
azureLogAnalytics: {URL: "http://ds"},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dsInfo datasourceInfo
|
||||
datasource *AzureLogAnalyticsDatasource
|
||||
expectedProxypass string
|
||||
expectedRouteURL string
|
||||
Err require.ErrorAssertionFunc
|
||||
name string
|
||||
expectedURL string
|
||||
expectedHeaders http.Header
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "plugin proxy route for the Azure public cloud",
|
||||
dsInfo: datasourceInfo{
|
||||
Settings: azureMonitorSettings{
|
||||
AzureAuthType: AzureAuthClientSecret,
|
||||
CloudName: "azuremonitor",
|
||||
},
|
||||
},
|
||||
datasource: &AzureLogAnalyticsDatasource{
|
||||
cfg: cfg,
|
||||
},
|
||||
expectedProxypass: "loganalyticsazure",
|
||||
expectedRouteURL: "https://api.loganalytics.io/",
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "plugin proxy route for the Azure China cloud",
|
||||
dsInfo: datasourceInfo{
|
||||
Settings: azureMonitorSettings{
|
||||
AzureAuthType: AzureAuthClientSecret,
|
||||
CloudName: "chinaazuremonitor",
|
||||
},
|
||||
},
|
||||
datasource: &AzureLogAnalyticsDatasource{
|
||||
cfg: cfg,
|
||||
},
|
||||
expectedProxypass: "chinaloganalyticsazure",
|
||||
expectedRouteURL: "https://api.loganalytics.azure.cn/",
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "plugin proxy route for the Azure Gov cloud",
|
||||
dsInfo: datasourceInfo{
|
||||
Settings: azureMonitorSettings{
|
||||
AzureAuthType: AzureAuthClientSecret,
|
||||
CloudName: "govazuremonitor",
|
||||
},
|
||||
},
|
||||
datasource: &AzureLogAnalyticsDatasource{
|
||||
cfg: cfg,
|
||||
},
|
||||
expectedProxypass: "govloganalyticsazure",
|
||||
expectedRouteURL: "https://api.loganalytics.us/",
|
||||
Err: require.NoError,
|
||||
name: "creates a request",
|
||||
expectedURL: "http://ds/",
|
||||
expectedHeaders: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
route, proxypass, err := tt.datasource.getPluginRoute(plugin, tt.dsInfo)
|
||||
ds := AzureLogAnalyticsDatasource{}
|
||||
req, err := ds.createRequest(ctx, dsInfo)
|
||||
tt.Err(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.expectedRouteURL, route.URL, cmpopts.EquateNaNs()); diff != "" {
|
||||
t.Errorf("Result mismatch (-want +got):\n%s", diff)
|
||||
if req.URL.String() != tt.expectedURL {
|
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedProxypass, proxypass, cmpopts.EquateNaNs()); diff != "" {
|
||||
t.Errorf("Result mismatch (-want +got):\n%s", diff)
|
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) {
|
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -15,11 +14,7 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/grafana/grafana/pkg/api/pluginproxy"
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
@@ -27,10 +22,7 @@ import (
|
||||
)
|
||||
|
||||
// AzureResourceGraphDatasource calls the Azure Resource Graph API's
|
||||
type AzureResourceGraphDatasource struct {
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
type AzureResourceGraphDatasource struct{}
|
||||
|
||||
// AzureResourceGraphQuery is the query request that is built from the saved values for
|
||||
// from the UI
|
||||
@@ -167,7 +159,7 @@ func (e *AzureResourceGraphDatasource) executeQuery(ctx context.Context, query *
|
||||
}
|
||||
|
||||
azlog.Debug("AzureResourceGraph", "Request ApiURL", req.URL.String())
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureResourceGraph].HTTPClient, req)
|
||||
if err != nil {
|
||||
return dataResponseErrorWithExecuted(err)
|
||||
}
|
||||
@@ -191,62 +183,18 @@ func (e *AzureResourceGraphDatasource) executeQuery(ctx context.Context, query *
|
||||
}
|
||||
|
||||
func (e *AzureResourceGraphDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo, reqBody []byte) (*http.Request, error) {
|
||||
// find plugin
|
||||
plugin := e.pluginManager.GetDataSource(dsName)
|
||||
if plugin == nil {
|
||||
return nil, errors.New("unable to find datasource plugin Azure Monitor")
|
||||
}
|
||||
|
||||
argRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(dsInfo.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "render")
|
||||
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(reqBody))
|
||||
req, err := http.NewRequest(http.MethodPost, dsInfo.Services[azureResourceGraph].URL, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
azlog.Debug("Failed to create request", "error", err)
|
||||
return nil, errutil.Wrap("failed to create request", err)
|
||||
}
|
||||
|
||||
req.URL.Path = "/"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
|
||||
|
||||
// TODO: Use backend authentication instead
|
||||
pluginproxy.ApplyRoute(ctx, req, routeName, argRoute, &models.DataSource{
|
||||
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
|
||||
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
|
||||
}, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (e *AzureResourceGraphDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
|
||||
cloud, err := getAzureCloud(e.cfg, dsInfo)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
routeName, err := getManagementApiRoute(cloud)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var pluginRoute *plugins.AppPluginRoute
|
||||
for _, route := range plugin.Routes {
|
||||
if route.Path == routeName {
|
||||
pluginRoute = route
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return pluginRoute, routeName, nil
|
||||
}
|
||||
|
||||
func (e *AzureResourceGraphDatasource) unmarshalResponse(res *http.Response) (AzureResourceGraphResponse, error) {
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
|
@@ -1,7 +1,9 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -71,3 +73,43 @@ func TestBuildingAzureResourceGraphQueries(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureResourceGraphCreateRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsInfo := datasourceInfo{
|
||||
Services: map[string]datasourceService{
|
||||
azureResourceGraph: {URL: "http://ds"},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedURL string
|
||||
expectedHeaders http.Header
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates a request",
|
||||
expectedURL: "http://ds/",
|
||||
expectedHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"User-Agent": []string{"Grafana/"},
|
||||
},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ds := AzureResourceGraphDatasource{}
|
||||
req, err := ds.createRequest(ctx, dsInfo, []byte{})
|
||||
tt.Err(t, err)
|
||||
if req.URL.String() != tt.expectedURL {
|
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
|
||||
}
|
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) {
|
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -3,7 +3,6 @@ package azuremonitor
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -15,11 +14,6 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/grafana/grafana/pkg/api/pluginproxy"
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
@@ -27,10 +21,7 @@ import (
|
||||
)
|
||||
|
||||
// AzureMonitorDatasource calls the Azure Monitor API - one of the four API's supported
|
||||
type AzureMonitorDatasource struct {
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
type AzureMonitorDatasource struct{}
|
||||
|
||||
var (
|
||||
// 1m, 5m, 15m, 30m, 1h, 6h, 12h, 1d in milliseconds
|
||||
@@ -189,7 +180,7 @@ func (e *AzureMonitorDatasource) executeQuery(ctx context.Context, query *AzureM
|
||||
|
||||
azlog.Debug("AzureMonitor", "Request ApiURL", req.URL.String())
|
||||
azlog.Debug("AzureMonitor", "Target", query.Target)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureMonitor].HTTPClient, req)
|
||||
if err != nil {
|
||||
dataResponse.Error = err
|
||||
return dataResponse, AzureMonitorResponse{}, nil
|
||||
@@ -210,63 +201,17 @@ func (e *AzureMonitorDatasource) executeQuery(ctx context.Context, query *AzureM
|
||||
}
|
||||
|
||||
func (e *AzureMonitorDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
|
||||
// find plugin
|
||||
plugin := e.pluginManager.GetDataSource(dsName)
|
||||
if plugin == nil {
|
||||
return nil, errors.New("unable to find datasource plugin Azure Monitor")
|
||||
}
|
||||
|
||||
azureMonitorRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(dsInfo.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "render")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[azureMonitor].URL, nil)
|
||||
if err != nil {
|
||||
azlog.Debug("Failed to create request", "error", err)
|
||||
return nil, errutil.Wrap("Failed to create request", err)
|
||||
}
|
||||
|
||||
req.URL.Path = "/subscriptions"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// TODO: Use backend authentication instead
|
||||
proxyPass := fmt.Sprintf("%s/subscriptions", routeName)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, azureMonitorRoute, &models.DataSource{
|
||||
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
|
||||
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
|
||||
}, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (e *AzureMonitorDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
|
||||
cloud, err := getAzureCloud(e.cfg, dsInfo)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
routeName, err := getManagementApiRoute(cloud)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var pluginRoute *plugins.AppPluginRoute
|
||||
for _, route := range plugin.Routes {
|
||||
if route.Path == routeName {
|
||||
pluginRoute = route
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return pluginRoute, routeName, nil
|
||||
}
|
||||
|
||||
func (e *AzureMonitorDatasource) unmarshalResponse(res *http.Response) (AzureMonitorResponse, error) {
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
|
@@ -1,9 +1,11 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -509,3 +511,42 @@ func loadTestFile(t *testing.T, name string) AzureMonitorResponse {
|
||||
require.NoError(t, err)
|
||||
return azData
|
||||
}
|
||||
|
||||
func TestAzureMonitorCreateRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsInfo := datasourceInfo{
|
||||
Services: map[string]datasourceService{
|
||||
azureMonitor: {URL: "http://ds"},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedURL string
|
||||
expectedHeaders http.Header
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates a request",
|
||||
expectedURL: "http://ds/subscriptions",
|
||||
expectedHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ds := AzureMonitorDatasource{}
|
||||
req, err := ds.createRequest(ctx, dsInfo)
|
||||
tt.Err(t, err)
|
||||
if req.URL.String() != tt.expectedURL {
|
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
|
||||
}
|
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) {
|
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
||||
@@ -39,7 +39,6 @@ func init() {
|
||||
|
||||
type Service struct {
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
HTTPClientProvider httpclient.Provider `inject:""`
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
BackendPluginManager backendplugin.Manager `inject:""`
|
||||
}
|
||||
@@ -59,30 +58,26 @@ type azureMonitorSettings struct {
|
||||
}
|
||||
|
||||
type datasourceInfo struct {
|
||||
Settings azureMonitorSettings
|
||||
Settings azureMonitorSettings
|
||||
Services map[string]datasourceService
|
||||
Routes map[string]azRoute
|
||||
HTTPCliOpts httpclient.Options
|
||||
|
||||
HTTPClient *http.Client
|
||||
URL string
|
||||
JSONData map[string]interface{}
|
||||
DecryptedSecureJSONData map[string]string
|
||||
DatasourceID int64
|
||||
OrgID int64
|
||||
}
|
||||
|
||||
func NewInstanceSettings(httpClientProvider httpclient.Provider) datasource.InstanceFactoryFunc {
|
||||
type datasourceService struct {
|
||||
URL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewInstanceSettings() datasource.InstanceFactoryFunc {
|
||||
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
|
||||
opts, err := settings.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := httpClientProvider.New(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonData := map[string]interface{}{}
|
||||
err = json.Unmarshal(settings.JSONData, &jsonData)
|
||||
err := json.Unmarshal(settings.JSONData, &jsonData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading settings: %w", err)
|
||||
}
|
||||
@@ -92,15 +87,20 @@ func NewInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading settings: %w", err)
|
||||
}
|
||||
httpCliOpts, err := settings.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting http options: %w", err)
|
||||
}
|
||||
|
||||
model := datasourceInfo{
|
||||
Settings: azMonitorSettings,
|
||||
HTTPClient: client,
|
||||
URL: settings.URL,
|
||||
JSONData: jsonData,
|
||||
DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
|
||||
DatasourceID: settings.ID,
|
||||
Services: map[string]datasourceService{},
|
||||
Routes: routes[azMonitorSettings.CloudName],
|
||||
HTTPCliOpts: httpCliOpts,
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
@@ -109,15 +109,8 @@ type azDatasourceExecutor interface {
|
||||
executeTimeSeriesQuery(ctx context.Context, originalQueries []backend.DataQuery, dsInfo datasourceInfo) (*backend.QueryDataResponse, error)
|
||||
}
|
||||
|
||||
func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC httpclient.Provider, cfg *setting.Cfg) *datasource.QueryTypeMux {
|
||||
func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors map[string]azDatasourceExecutor) *datasource.QueryTypeMux {
|
||||
mux := datasource.NewQueryTypeMux()
|
||||
executors := map[string]azDatasourceExecutor{
|
||||
"Azure Monitor": &AzureMonitorDatasource{pm, cfg},
|
||||
"Application Insights": &ApplicationInsightsDatasource{pm, cfg},
|
||||
"Azure Log Analytics": &AzureLogAnalyticsDatasource{pm, cfg},
|
||||
"Insights Analytics": &InsightsAnalyticsDatasource{pm, cfg},
|
||||
"Azure Resource Graph": &AzureResourceGraphDatasource{pm, cfg},
|
||||
}
|
||||
for dsType := range executors {
|
||||
// Make a copy of the string to keep the reference after the iterator
|
||||
dst := dsType
|
||||
@@ -129,6 +122,18 @@ func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC http
|
||||
dsInfo := i.(datasourceInfo)
|
||||
dsInfo.OrgID = req.PluginContext.OrgID
|
||||
ds := executors[dst]
|
||||
if _, ok := dsInfo.Services[dst]; !ok {
|
||||
// Create an HTTP Client if it has not been created before
|
||||
route := dsInfo.Routes[dst]
|
||||
client, err := newHTTPClient(ctx, route, dsInfo, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dsInfo.Services[dst] = datasourceService{
|
||||
URL: dsInfo.Routes[dst].URL,
|
||||
HTTPClient: client,
|
||||
}
|
||||
}
|
||||
return ds.executeTimeSeriesQuery(ctx, req.Queries, dsInfo)
|
||||
})
|
||||
}
|
||||
@@ -136,9 +141,16 @@ func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC http
|
||||
}
|
||||
|
||||
func (s *Service) Init() error {
|
||||
im := datasource.NewInstanceManager(NewInstanceSettings(s.HTTPClientProvider))
|
||||
im := datasource.NewInstanceManager(NewInstanceSettings())
|
||||
executors := map[string]azDatasourceExecutor{
|
||||
azureMonitor: &AzureMonitorDatasource{},
|
||||
appInsights: &ApplicationInsightsDatasource{},
|
||||
azureLogAnalytics: &AzureLogAnalyticsDatasource{},
|
||||
insightsAnalytics: &InsightsAnalyticsDatasource{},
|
||||
azureResourceGraph: &AzureResourceGraphDatasource{},
|
||||
}
|
||||
factory := coreplugin.New(backend.ServeOpts{
|
||||
QueryDataHandler: newExecutor(im, s.PluginManager, s.HTTPClientProvider, s.Cfg),
|
||||
QueryDataHandler: newExecutor(im, s.Cfg, executors),
|
||||
})
|
||||
|
||||
if err := s.BackendPluginManager.Register(dsName, factory); err != nil {
|
||||
|
127
pkg/tsdb/azuremonitor/azuremonitor_test.go
Normal file
127
pkg/tsdb/azuremonitor/azuremonitor_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewInstanceSettings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
settings backend.DataSourceInstanceSettings
|
||||
expectedModel datasourceInfo
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates an instance",
|
||||
settings: backend.DataSourceInstanceSettings{
|
||||
JSONData: []byte(`{"cloudName":"azuremonitor"}`),
|
||||
DecryptedSecureJSONData: map[string]string{"key": "value"},
|
||||
ID: 40,
|
||||
},
|
||||
expectedModel: datasourceInfo{
|
||||
Settings: azureMonitorSettings{CloudName: "azuremonitor"},
|
||||
Routes: routes["azuremonitor"],
|
||||
JSONData: map[string]interface{}{"cloudName": string("azuremonitor")},
|
||||
DatasourceID: 40,
|
||||
DecryptedSecureJSONData: map[string]string{"key": "value"},
|
||||
},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
factory := NewInstanceSettings()
|
||||
instance, err := factory(tt.settings)
|
||||
tt.Err(t, err)
|
||||
if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) {
|
||||
t.Errorf("Unexpected instance: %v", cmp.Diff(instance, tt.expectedModel))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeInstance struct{}
|
||||
|
||||
func (f *fakeInstance) Get(pluginContext backend.PluginContext) (instancemgmt.Instance, error) {
|
||||
return datasourceInfo{
|
||||
Services: map[string]datasourceService{},
|
||||
Routes: routes[azureMonitorPublic],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeInstance) Do(pluginContext backend.PluginContext, fn instancemgmt.InstanceCallbackFunc) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeExecutor struct {
|
||||
t *testing.T
|
||||
queryType string
|
||||
expectedURL string
|
||||
}
|
||||
|
||||
func (f *fakeExecutor) executeTimeSeriesQuery(ctx context.Context, originalQueries []backend.DataQuery, dsInfo datasourceInfo) (*backend.QueryDataResponse, error) {
|
||||
if s, ok := dsInfo.Services[f.queryType]; !ok {
|
||||
f.t.Errorf("The HTTP client for %s is missing", f.queryType)
|
||||
} else {
|
||||
if s.URL != f.expectedURL {
|
||||
f.t.Errorf("Unexpected URL %s wanted %s", s.URL, f.expectedURL)
|
||||
}
|
||||
}
|
||||
return &backend.QueryDataResponse{}, nil
|
||||
}
|
||||
|
||||
func Test_newExecutor(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryType string
|
||||
expectedURL string
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates an Azure Monitor executor",
|
||||
queryType: azureMonitor,
|
||||
expectedURL: routes[azureMonitorPublic][azureMonitor].URL,
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "creates an Azure Log Analytics executor",
|
||||
queryType: azureLogAnalytics,
|
||||
expectedURL: routes[azureMonitorPublic][azureLogAnalytics].URL,
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mux := newExecutor(&fakeInstance{}, cfg, map[string]azDatasourceExecutor{
|
||||
tt.queryType: &fakeExecutor{
|
||||
t: t,
|
||||
queryType: tt.queryType,
|
||||
expectedURL: tt.expectedURL,
|
||||
},
|
||||
})
|
||||
res, err := mux.QueryData(context.TODO(), &backend.QueryDataRequest{
|
||||
PluginContext: backend.PluginContext{},
|
||||
Queries: []backend.DataQuery{
|
||||
{QueryType: tt.queryType},
|
||||
},
|
||||
})
|
||||
tt.Err(t, err)
|
||||
// Dummy response from the fake implementation
|
||||
if res == nil {
|
||||
t.Errorf("Expecting a response")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,14 +1,13 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
const (
|
||||
AzureAuthManagedIdentity = "msi"
|
||||
AzureAuthClientSecret = "clientsecret"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
|
||||
)
|
||||
|
||||
// Azure cloud names specific to Azure Monitor
|
||||
@@ -19,59 +18,40 @@ const (
|
||||
azureMonitorGermany = "germanyazuremonitor"
|
||||
)
|
||||
|
||||
func getAuthType(cfg *setting.Cfg, dsInfo datasourceInfo) string {
|
||||
if dsInfo.Settings.AzureAuthType != "" {
|
||||
return dsInfo.Settings.AzureAuthType
|
||||
// Azure cloud query types
|
||||
const (
|
||||
azureMonitor = "Azure Monitor"
|
||||
appInsights = "Application Insights"
|
||||
azureLogAnalytics = "Azure Log Analytics"
|
||||
insightsAnalytics = "Insights Analytics"
|
||||
azureResourceGraph = "Azure Resource Graph"
|
||||
)
|
||||
|
||||
func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
|
||||
if len(route.Scopes) > 0 {
|
||||
tokenAuth := &plugins.JwtTokenAuth{
|
||||
Url: route.URL,
|
||||
Scopes: route.Scopes,
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": model.Settings.AzureAuthType,
|
||||
"azure_cloud": cfg.Azure.Cloud,
|
||||
"tenant_id": model.Settings.TenantId,
|
||||
"client_id": model.Settings.ClientId,
|
||||
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
|
||||
},
|
||||
}
|
||||
tokenProvider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
return httpclient.NewProvider(httpclient.ProviderOptions{
|
||||
Middlewares: []httpclient.Middleware{
|
||||
tokenprovider.AuthMiddleware(tokenProvider),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
tenantId := dsInfo.Settings.TenantId
|
||||
clientId := dsInfo.Settings.ClientId
|
||||
|
||||
// If authentication type isn't explicitly specified and datasource has client credentials,
|
||||
// then this is existing datasource which is configured for app registration (client secret)
|
||||
if tenantId != "" && clientId != "" {
|
||||
return AzureAuthClientSecret
|
||||
}
|
||||
|
||||
// For newly created datasource with no configuration, managed identity is the default authentication type
|
||||
// if they are enabled in Grafana config
|
||||
if cfg.Azure.ManagedIdentityEnabled {
|
||||
return AzureAuthManagedIdentity
|
||||
} else {
|
||||
return AzureAuthClientSecret
|
||||
}
|
||||
return httpclient.NewProvider()
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultAzureCloud(cfg *setting.Cfg) (string, error) {
|
||||
switch cfg.Azure.Cloud {
|
||||
case setting.AzurePublic:
|
||||
return azureMonitorPublic, nil
|
||||
case setting.AzureChina:
|
||||
return azureMonitorChina, nil
|
||||
case setting.AzureUSGovernment:
|
||||
return azureMonitorUSGovernment, nil
|
||||
case setting.AzureGermany:
|
||||
return azureMonitorGermany, nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", cfg.Azure.Cloud)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureCloud(cfg *setting.Cfg, dsInfo datasourceInfo) (string, error) {
|
||||
authType := getAuthType(cfg, dsInfo)
|
||||
switch authType {
|
||||
case AzureAuthManagedIdentity:
|
||||
// In case of managed identity, the cloud is always same as where Grafana is hosted
|
||||
return getDefaultAzureCloud(cfg)
|
||||
case AzureAuthClientSecret:
|
||||
if dsInfo.Settings.CloudName != "" {
|
||||
return dsInfo.Settings.CloudName, nil
|
||||
} else {
|
||||
return getDefaultAzureCloud(cfg)
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("the authentication type '%s' not supported", authType)
|
||||
return "", err
|
||||
}
|
||||
func newHTTPClient(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
|
||||
model.HTTPCliOpts.Headers = route.Headers
|
||||
return httpClientProvider(ctx, route, model, cfg).New(model.HTTPCliOpts)
|
||||
}
|
||||
|
54
pkg/tsdb/azuremonitor/credentials_test.go
Normal file
54
pkg/tsdb/azuremonitor/credentials_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_httpCliProvider(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
cfg := &setting.Cfg{}
|
||||
model := datasourceInfo{
|
||||
Settings: azureMonitorSettings{},
|
||||
DecryptedSecureJSONData: map[string]string{"clientSecret": "content"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
route azRoute
|
||||
expectedMiddlewares int
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates an HTTP client with a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{"http://route/.default"},
|
||||
},
|
||||
expectedMiddlewares: 1,
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "creates an HTTP client without a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{},
|
||||
},
|
||||
// httpclient.NewProvider returns a client with 2 middlewares by default
|
||||
expectedMiddlewares: 2,
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := httpClientProvider(ctx, tt.route, model, cfg)
|
||||
// Cannot test that the cli middleware works properly since the azcore sdk
|
||||
// rejects the TLS certs (if provided)
|
||||
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
|
||||
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -13,21 +12,12 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/grafana/grafana/pkg/api/pluginproxy"
|
||||
"github.com/grafana/grafana/pkg/components/securejsondata"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
)
|
||||
|
||||
type InsightsAnalyticsDatasource struct {
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
type InsightsAnalyticsDatasource struct{}
|
||||
|
||||
type InsightsAnalyticsQuery struct {
|
||||
RefID string
|
||||
@@ -122,7 +112,7 @@ func (e *InsightsAnalyticsDatasource) executeQuery(ctx context.Context, query *I
|
||||
}
|
||||
|
||||
azlog.Debug("ApplicationInsights", "Request URL", req.URL.String())
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
|
||||
res, err := ctxhttp.Do(ctx, dsInfo.Services[appInsights].HTTPClient, req)
|
||||
if err != nil {
|
||||
return dataResponseError(err)
|
||||
}
|
||||
@@ -179,59 +169,14 @@ func (e *InsightsAnalyticsDatasource) executeQuery(ctx context.Context, query *I
|
||||
}
|
||||
|
||||
func (e *InsightsAnalyticsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
|
||||
// find plugin
|
||||
plugin := e.pluginManager.GetDataSource(dsName)
|
||||
if plugin == nil {
|
||||
return nil, errors.New("unable to find datasource plugin Azure Application Insights")
|
||||
}
|
||||
|
||||
appInsightsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appInsightsAppID := dsInfo.Settings.AppInsightsAppId
|
||||
|
||||
u, err := url.Parse(dsInfo.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse url for Application Insights Analytics datasource: %w", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, fmt.Sprintf("/v1/apps/%s", appInsightsAppID))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[insightsAnalytics].URL, nil)
|
||||
if err != nil {
|
||||
azlog.Debug("Failed to create request", "error", err)
|
||||
return nil, errutil.Wrap("Failed to create request", err)
|
||||
}
|
||||
|
||||
// TODO: Use backend authentication instead
|
||||
proxyPass := fmt.Sprintf("%s/v1/apps/%s", routeName, appInsightsAppID)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, &models.DataSource{
|
||||
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
|
||||
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
|
||||
}, e.cfg)
|
||||
|
||||
req.Header.Set("X-API-Key", dsInfo.DecryptedSecureJSONData["appInsightsApiKey"])
|
||||
req.URL.Path = fmt.Sprintf("/v1/apps/%s", appInsightsAppID)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (e *InsightsAnalyticsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
|
||||
cloud, err := getAzureCloud(e.cfg, dsInfo)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
routeName, err := getAppInsightsApiRoute(cloud)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var pluginRoute *plugins.AppPluginRoute
|
||||
for _, route := range plugin.Routes {
|
||||
if route.Path == routeName {
|
||||
pluginRoute = route
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return pluginRoute, routeName, nil
|
||||
}
|
||||
|
53
pkg/tsdb/azuremonitor/insights-analytics-datasource_test.go
Normal file
53
pkg/tsdb/azuremonitor/insights-analytics-datasource_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInsightsAnalyticsCreateRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsInfo := datasourceInfo{
|
||||
Settings: azureMonitorSettings{AppInsightsAppId: "foo"},
|
||||
Services: map[string]datasourceService{
|
||||
insightsAnalytics: {URL: "http://ds"},
|
||||
},
|
||||
DecryptedSecureJSONData: map[string]string{
|
||||
"appInsightsApiKey": "key",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedURL string
|
||||
expectedHeaders http.Header
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates a request",
|
||||
expectedURL: "http://ds/v1/apps/foo",
|
||||
expectedHeaders: http.Header{
|
||||
"X-Api-Key": []string{"key"},
|
||||
},
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ds := InsightsAnalyticsDatasource{}
|
||||
req, err := ds.createRequest(ctx, dsInfo)
|
||||
tt.Err(t, err)
|
||||
if req.URL.String() != tt.expectedURL {
|
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
|
||||
}
|
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) {
|
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,45 +1,90 @@
|
||||
package azuremonitor
|
||||
|
||||
import "fmt"
|
||||
|
||||
func getManagementApiRoute(azureCloud string) (string, error) {
|
||||
switch azureCloud {
|
||||
case azureMonitorPublic:
|
||||
return "azuremonitor", nil
|
||||
case azureMonitorChina:
|
||||
return "chinaazuremonitor", nil
|
||||
case azureMonitorUSGovernment:
|
||||
return "govazuremonitor", nil
|
||||
case azureMonitorGermany:
|
||||
return "germanyazuremonitor", nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
|
||||
return "", err
|
||||
}
|
||||
type azRoute struct {
|
||||
URL string
|
||||
Scopes []string
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
func getLogAnalyticsApiRoute(azureCloud string) (string, error) {
|
||||
switch azureCloud {
|
||||
case azureMonitorPublic:
|
||||
return "loganalyticsazure", nil
|
||||
case azureMonitorChina:
|
||||
return "chinaloganalyticsazure", nil
|
||||
case azureMonitorUSGovernment:
|
||||
return "govloganalyticsazure", nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
|
||||
return "", err
|
||||
}
|
||||
var azManagement = azRoute{
|
||||
URL: "https://management.azure.com",
|
||||
Scopes: []string{"https://management.azure.com/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
func getAppInsightsApiRoute(azureCloud string) (string, error) {
|
||||
switch azureCloud {
|
||||
case azureMonitorPublic:
|
||||
return "appinsights", nil
|
||||
case azureMonitorChina:
|
||||
return "chinaappinsights", nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
|
||||
return "", err
|
||||
}
|
||||
var azUSGovManagement = azRoute{
|
||||
URL: "https://management.usgovcloudapi.net",
|
||||
Scopes: []string{"https://management.usgovcloudapi.net/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
var azGermanyManagement = azRoute{
|
||||
URL: "https://management.microsoftazure.de",
|
||||
Scopes: []string{"https://management.microsoftazure.de/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
var azChinaManagement = azRoute{
|
||||
URL: "https://management.chinacloudapi.cn",
|
||||
Scopes: []string{"https://management.chinacloudapi.cn/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
var azAppInsights = azRoute{
|
||||
URL: "https://api.applicationinsights.io",
|
||||
Scopes: []string{},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
var azChinaAppInsights = azRoute{
|
||||
URL: "https://api.applicationinsights.azure.cn",
|
||||
Scopes: []string{},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana"},
|
||||
}
|
||||
|
||||
var azLogAnalytics = azRoute{
|
||||
URL: "https://api.loganalytics.io",
|
||||
Scopes: []string{"https://api.loganalytics.io/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
|
||||
}
|
||||
|
||||
var azChinaLogAnalytics = azRoute{
|
||||
URL: "https://api.loganalytics.azure.cn",
|
||||
Scopes: []string{"https://api.loganalytics.azure.cn/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
|
||||
}
|
||||
|
||||
var azUSGovLogAnalytics = azRoute{
|
||||
URL: "https://api.loganalytics.us",
|
||||
Scopes: []string{"https://api.loganalytics.us/.default"},
|
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
|
||||
}
|
||||
|
||||
var (
|
||||
// The different Azure routes are identified by its cloud (e.g. public or gov)
|
||||
// and the service to query (e.g. Azure Monitor or Azure Log Analytics)
|
||||
routes = map[string]map[string]azRoute{
|
||||
azureMonitorPublic: {
|
||||
azureMonitor: azManagement,
|
||||
azureLogAnalytics: azLogAnalytics,
|
||||
azureResourceGraph: azManagement,
|
||||
appInsights: azAppInsights,
|
||||
insightsAnalytics: azAppInsights,
|
||||
},
|
||||
azureMonitorUSGovernment: {
|
||||
azureMonitor: azUSGovManagement,
|
||||
azureLogAnalytics: azUSGovLogAnalytics,
|
||||
azureResourceGraph: azUSGovManagement,
|
||||
},
|
||||
azureMonitorGermany: {
|
||||
azureMonitor: azGermanyManagement,
|
||||
},
|
||||
azureMonitorChina: {
|
||||
azureMonitor: azChinaManagement,
|
||||
azureLogAnalytics: azChinaLogAnalytics,
|
||||
azureResourceGraph: azChinaManagement,
|
||||
appInsights: azChinaAppInsights,
|
||||
insightsAnalytics: azChinaAppInsights,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
@@ -0,0 +1,33 @@
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
var (
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type TokenProvider interface {
|
||||
GetAccessToken() (string, error)
|
||||
}
|
||||
|
||||
const authenticationMiddlewareName = "AzureAuthentication"
|
||||
|
||||
func AuthMiddleware(tokenProvider TokenProvider) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve azure access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
})
|
||||
}
|
184
pkg/tsdb/azuremonitor/tokenprovider/token_cache.go
Normal file
184
pkg/tsdb/azuremonitor/tokenprovider/token_cache.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenCredential interface {
|
||||
GetCacheKey() string
|
||||
Init() error
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
type ConcurrentTokenCache interface {
|
||||
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache {
|
||||
return &tokenCacheImpl{}
|
||||
}
|
||||
|
||||
type tokenCacheImpl struct {
|
||||
cache sync.Map // of *credentialCacheEntry
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
credential TokenCredential
|
||||
|
||||
credInit uint32
|
||||
credMutex sync.Mutex
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
credential TokenCredential
|
||||
scopes []string
|
||||
|
||||
cond *sync.Cond
|
||||
refreshing bool
|
||||
accessToken *AccessToken
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := credential.GetCacheKey()
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
|
||||
credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*credentialCacheEntry)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
err := c.ensureInitialized()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.getEntryFor(scopes).getAccessToken(ctx)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) ensureInitialized() error {
|
||||
if atomic.LoadUint32(&c.credInit) == 0 {
|
||||
c.credMutex.Lock()
|
||||
defer c.credMutex.Unlock()
|
||||
|
||||
if c.credInit == 0 {
|
||||
// Initialize credential
|
||||
err := c.credential.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.credInit, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := getKeyForScopes(scopes)
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
|
||||
credential: c.credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*scopesCacheEntry)
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
var accessToken *AccessToken
|
||||
var err error
|
||||
shouldRefresh := false
|
||||
|
||||
c.cond.L.Lock()
|
||||
for {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
|
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken
|
||||
break
|
||||
}
|
||||
|
||||
if !c.refreshing {
|
||||
// Start refreshing the token
|
||||
c.refreshing = true
|
||||
shouldRefresh = true
|
||||
break
|
||||
}
|
||||
|
||||
// Wait for the token to be refreshed
|
||||
c.cond.Wait()
|
||||
}
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if shouldRefresh {
|
||||
accessToken, err = c.refreshAccessToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken.Token, nil
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
|
||||
var accessToken *AccessToken
|
||||
|
||||
// Safeguarding from panic caused by credential implementation
|
||||
defer func() {
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
|
||||
if accessToken != nil {
|
||||
c.accessToken = accessToken
|
||||
}
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
}()
|
||||
|
||||
token, err := c.credential.GetAccessToken(ctx, c.scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken = token
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func getKeyForScopes(scopes []string) string {
|
||||
if len(scopes) > 1 {
|
||||
arr := make([]string, len(scopes))
|
||||
copy(arr, scopes)
|
||||
sort.Strings(arr)
|
||||
scopes = arr
|
||||
}
|
||||
|
||||
return strings.Join(scopes, " ")
|
||||
}
|
457
pkg/tsdb/azuremonitor/tokenprovider/token_cache_test.go
Normal file
457
pkg/tsdb/azuremonitor/tokenprovider/token_cache_test.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeCredential struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
initFunc func() error
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
return c.getAccessTokenFunc(ctx, scopes)
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
}
|
||||
|
||||
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes1 := []string{"Scope1"}
|
||||
scopes2 := []string{"Scope2"}
|
||||
|
||||
t.Run("should request access token from credential", func(t *testing.T) {
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token, err := cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token)
|
||||
|
||||
assert.Equal(t, 1, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same credentials", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential1 := &fakeCredential{key: "credential-1"}
|
||||
credential2 := &fakeCredential{key: "credential-2"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
assert.Equal(t, 1, credential1.calledTimes)
|
||||
assert.Equal(t, 1, credential2.calledTimes)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when credential init returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
return errors.New("unable to initialize")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to initialize"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "", accessToken)
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to get access token"))
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package pluginproxy
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
@@ -18,27 +17,21 @@ var (
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
datasourceId int64
|
||||
datasourceVersion int
|
||||
ctx context.Context
|
||||
cfg *setting.Cfg
|
||||
route *plugins.AppPluginRoute
|
||||
authParams *plugins.JwtTokenAuth
|
||||
ctx context.Context
|
||||
cfg *setting.Cfg
|
||||
authParams *plugins.JwtTokenAuth
|
||||
}
|
||||
|
||||
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
func NewAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg,
|
||||
authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
|
||||
return &azureAccessTokenProvider{
|
||||
datasourceId: ds.Id,
|
||||
datasourceVersion: ds.Version,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
route: pluginRoute,
|
||||
authParams: authParams,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
authParams: authParams,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getAccessToken() (string, error) {
|
||||
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
var credential TokenCredential
|
||||
|
||||
if provider.isManagedIdentityCredential() {
|
@@ -1,10 +1,9 @@
|
||||
package pluginproxy
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -15,7 +14,7 @@ var getAccessTokenFunc func(credential TokenCredential, scopes []string)
|
||||
|
||||
type tokenCacheFake struct{}
|
||||
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
func (c *tokenCacheFake) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
getAccessTokenFunc(credential, scopes)
|
||||
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
|
||||
}
|
||||
@@ -25,9 +24,6 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
@@ -41,7 +37,7 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
@@ -114,9 +110,6 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
@@ -130,7 +123,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
|
||||
original := azureTokenCache
|
||||
azureTokenCache = &tokenCacheFake{}
|
||||
@@ -148,7 +141,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.IsType(t, &managedIdentityCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
_, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -161,7 +154,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.IsType(t, &clientSecretCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
_, err := provider.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
@@ -178,7 +171,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.Fail(t, "token cache not expected to be called")
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
_, err := provider.GetAccessToken()
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
@@ -189,9 +182,6 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
@@ -205,7 +195,7 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
|
||||
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
|
||||
result := provider.getClientSecretCredential()
|
Reference in New Issue
Block a user