mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Azure token provider with support for Managed Identities (#33807)
* Azure token provider * Configuration for Azure token provider * Authentication via Azure SDK for Go * Fix typo * ConcurrentTokenCache for Azure credentials * Resolve AAD authority for selected Azure cloud * Fixes * Generic AccessToken and fixes * Tests and wordings * Tests for getAccessToken * Tests for getClientSecretCredential * Tests for token cache
This commit is contained in:
@@ -9,12 +9,13 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
// ApplyRoute should use the plugin route data to set auth headers and custom headers.
|
||||
func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route *plugins.AppPluginRoute,
|
||||
ds *models.DataSource) {
|
||||
ds *models.DataSource, cfg *setting.Cfg) {
|
||||
proxyPath = strings.TrimPrefix(proxyPath, route.Path)
|
||||
|
||||
data := templateData{
|
||||
@@ -53,7 +54,7 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
|
||||
logger.Error("Failed to set plugin route body content", "error", err)
|
||||
}
|
||||
|
||||
if tokenProvider, err := getTokenProvider(ctx, ds, route, data); err != nil {
|
||||
if tokenProvider, err := getTokenProvider(ctx, cfg, ds, route, data); err != nil {
|
||||
logger.Error("Failed to resolve auth token provider", "error", err)
|
||||
} else if tokenProvider != nil {
|
||||
if token, err := tokenProvider.getAccessToken(); err != nil {
|
||||
@@ -66,7 +67,7 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
|
||||
logger.Info("Requesting", "url", req.URL.String())
|
||||
}
|
||||
|
||||
func getTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
data templateData) (accessTokenProvider, error) {
|
||||
authType := pluginRoute.AuthType
|
||||
|
||||
@@ -85,6 +86,13 @@ func getTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *p
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "azure":
|
||||
if tokenAuth == nil {
|
||||
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
|
||||
}
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, pluginRoute, tokenAuth)
|
||||
return provider, nil
|
||||
|
||||
case "gce":
|
||||
if jwtTokenAuth == nil {
|
||||
return nil, fmt.Errorf("'jwtTokenAuth' not configured for authentication type '%s'", authType)
|
||||
|
||||
@@ -231,7 +231,7 @@ func (proxy *DataSourceProxy) director(req *http.Request) {
|
||||
req.Header.Del("Referer")
|
||||
|
||||
if proxy.route != nil {
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, proxy.cfg)
|
||||
}
|
||||
|
||||
if oauthtoken.IsOAuthPassThruEnabled(proxy.ds) {
|
||||
|
||||
@@ -109,12 +109,14 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
return ctx, req
|
||||
}
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
t.Run("When matching route path", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg)
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[0]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
|
||||
assert.Equal(t, "https://www.google.com/some/method", req.URL.String())
|
||||
assert.Equal(t, "my secret 123", req.Header.Get("x-header"))
|
||||
@@ -122,10 +124,10 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path and has dynamic url", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/common/some/method", cfg)
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[3]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
|
||||
assert.Equal(t, "https://dynamic.grafana.com/some/method?apiKey=123", req.URL.String())
|
||||
assert.Equal(t, "my secret 123", req.Header.Get("x-header"))
|
||||
@@ -133,20 +135,20 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("When matching route path with no url", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "", cfg)
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[4]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
|
||||
assert.Equal(t, "http://localhost/asd", req.URL.String())
|
||||
})
|
||||
|
||||
t.Run("When matching route path and has dynamic body", func(t *testing.T) {
|
||||
ctx, req := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/body", cfg)
|
||||
require.NoError(t, err)
|
||||
proxy.route = plugin.Routes[5]
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds, cfg)
|
||||
|
||||
content, err := ioutil.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
@@ -156,7 +158,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
t.Run("Validating request", func(t *testing.T) {
|
||||
t.Run("plugin route with valid role", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/v4/some/method", cfg)
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.NoError(t, err)
|
||||
@@ -164,7 +166,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
|
||||
t.Run("plugin route with admin role and user is editor", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg)
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.Error(t, err)
|
||||
@@ -173,7 +175,7 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
t.Run("plugin route with admin role and user is admin", func(t *testing.T) {
|
||||
ctx, _ := setUp()
|
||||
ctx.SignedInUser.OrgRole = models.ROLE_ADMIN
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "api/admin", cfg)
|
||||
require.NoError(t, err)
|
||||
err = proxy.validateRequest()
|
||||
require.NoError(t, err)
|
||||
@@ -253,9 +255,11 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
client = newFakeHTTPClient(t, json)
|
||||
defer func() { client = originalClient }()
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", &setting.Cfg{})
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg)
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
|
||||
|
||||
authorizationHeaderCall1 = req.Header.Get("Authorization")
|
||||
assert.Equal(t, "https://api.nr1.io/some/path", req.URL.String())
|
||||
@@ -268,9 +272,9 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://localhost/asd", nil)
|
||||
require.NoError(t, err)
|
||||
client = newFakeHTTPClient(t, json2)
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken2", cfg)
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[1], proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[1], proxy.ds, cfg)
|
||||
|
||||
authorizationHeaderCall2 = req.Header.Get("Authorization")
|
||||
|
||||
@@ -284,9 +288,9 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
client = newFakeHTTPClient(t, []byte{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", &setting.Cfg{})
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", cfg)
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds, cfg)
|
||||
|
||||
authorizationHeaderCall3 := req.Header.Get("Authorization")
|
||||
assert.Equal(t, "https://api.nr1.io/some/path", req.URL.String())
|
||||
|
||||
127
pkg/api/pluginproxy/token_cache.go
Normal file
127
pkg/api/pluginproxy/token_cache.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenCredential interface {
|
||||
GetCacheKey() string
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
type ConcurrentTokenCache interface {
|
||||
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache {
|
||||
return &tokenCacheImpl{}
|
||||
}
|
||||
|
||||
type tokenCacheImpl struct {
|
||||
cache sync.Map // of *credentialCacheEntry
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
credential TokenCredential
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
credential TokenCredential
|
||||
scopes []string
|
||||
|
||||
cond *sync.Cond
|
||||
refreshing bool
|
||||
accessToken *AccessToken
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
credentialKey := credential.GetCacheKey()
|
||||
scopesKey := getKeyForScopes(scopes)
|
||||
|
||||
if entry, ok = c.cache.Load(credentialKey); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(credentialKey, &credentialCacheEntry{
|
||||
credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
credentialEntry := entry.(*credentialCacheEntry)
|
||||
|
||||
if entry, ok = credentialEntry.cache.Load(scopesKey); !ok {
|
||||
entry, _ = credentialEntry.cache.LoadOrStore(scopesKey, &scopesCacheEntry{
|
||||
credential: credentialEntry.credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
scopesEntry := entry.(*scopesCacheEntry)
|
||||
|
||||
return scopesEntry.getAccessToken(ctx)
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
var accessToken *AccessToken
|
||||
var err error
|
||||
shouldRefresh := false
|
||||
|
||||
c.cond.L.Lock()
|
||||
for {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
|
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken
|
||||
break
|
||||
}
|
||||
|
||||
if !c.refreshing {
|
||||
// Start refreshing the token
|
||||
c.refreshing = true
|
||||
shouldRefresh = true
|
||||
break
|
||||
}
|
||||
|
||||
// Wait for the token to be refreshed
|
||||
c.cond.Wait()
|
||||
}
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if shouldRefresh {
|
||||
accessToken, err = c.credential.GetAccessToken(ctx, c.scopes)
|
||||
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
c.accessToken = accessToken
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken.Token, nil
|
||||
}
|
||||
|
||||
func getKeyForScopes(scopes []string) string {
|
||||
if len(scopes) > 1 {
|
||||
arr := make([]string, len(scopes))
|
||||
copy(arr, scopes)
|
||||
sort.Strings(arr)
|
||||
scopes = arr
|
||||
}
|
||||
|
||||
return strings.Join(scopes, " ")
|
||||
}
|
||||
102
pkg/api/pluginproxy/token_cache_test.go
Normal file
102
pkg/api/pluginproxy/token_cache_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeCredential struct {
|
||||
key string
|
||||
calledTimes int
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
return c.getAccessTokenFunc(ctx, scopes)
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
}
|
||||
|
||||
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes1 := []string{"Scope1"}
|
||||
scopes2 := []string{"Scope2"}
|
||||
|
||||
t.Run("should request access token from credential", func(t *testing.T) {
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token, err := cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token)
|
||||
|
||||
assert.Equal(t, 1, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same credentials", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential1 := &fakeCredential{key: "credential-1"}
|
||||
credential2 := &fakeCredential{key: "credential-2"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
assert.Equal(t, 1, credential1.calledTimes)
|
||||
assert.Equal(t, 1, credential2.calledTimes)
|
||||
})
|
||||
}
|
||||
180
pkg/api/pluginproxy/token_provider_azure.go
Normal file
180
pkg/api/pluginproxy/token_provider_azure.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
azureTokenCache = NewConcurrentTokenCache()
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
datasourceId int64
|
||||
datasourceVersion int
|
||||
ctx context.Context
|
||||
cfg *setting.Cfg
|
||||
route *plugins.AppPluginRoute
|
||||
authParams *plugins.JwtTokenAuth
|
||||
}
|
||||
|
||||
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
|
||||
return &azureAccessTokenProvider{
|
||||
datasourceId: ds.Id,
|
||||
datasourceVersion: ds.Version,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
route: pluginRoute,
|
||||
authParams: authParams,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getAccessToken() (string, error) {
|
||||
var credential TokenCredential
|
||||
|
||||
if provider.isManagedIdentityCredential() {
|
||||
if !provider.cfg.Azure.ManagedIdentityEnabled {
|
||||
err := fmt.Errorf("managed identity authentication not enabled in Grafana config")
|
||||
return "", err
|
||||
} else {
|
||||
credential = provider.getManagedIdentityCredential()
|
||||
}
|
||||
} else {
|
||||
credential = provider.getClientSecretCredential()
|
||||
}
|
||||
|
||||
accessToken, err := azureTokenCache.GetAccessToken(provider.ctx, credential, provider.authParams.Scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) isManagedIdentityCredential() bool {
|
||||
authType := strings.ToLower(provider.authParams.Params["azure_auth_type"])
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
|
||||
// Type of authentication being determined by the following logic:
|
||||
// * If authType is set to 'msi' then user explicitly selected the managed identity authentication
|
||||
// * If authType isn't set but other fields are configured then it's a datasource which was configured
|
||||
// before managed identities where introduced, therefore use client secret authentication
|
||||
// * If authType and other fields aren't set then it means the datasource never been configured
|
||||
// and managed identity is the default authentication choice as long as managed identities are enabled
|
||||
return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled)
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getManagedIdentityCredential() TokenCredential {
|
||||
clientId := provider.cfg.Azure.ManagedIdentityClientId
|
||||
|
||||
return &managedIdentityCredential{clientId: clientId}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getClientSecretCredential() TokenCredential {
|
||||
authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"])
|
||||
tenantId := provider.authParams.Params["tenant_id"]
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
clientSecret := provider.authParams.Params["client_secret"]
|
||||
|
||||
return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string) string {
|
||||
// Known Azure clouds
|
||||
switch cloudName {
|
||||
case setting.AzurePublic:
|
||||
return azidentity.AzurePublicCloud
|
||||
case setting.AzureChina:
|
||||
return azidentity.AzureChina
|
||||
case setting.AzureUSGovernment:
|
||||
return azidentity.AzureGovernment
|
||||
case setting.AzureGermany:
|
||||
return azidentity.AzureGermany
|
||||
}
|
||||
// Fallback to direct URL
|
||||
return provider.authParams.Url
|
||||
}
|
||||
|
||||
type managedIdentityCredential struct {
|
||||
clientId string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
clientId := c.clientId
|
||||
if clientId == "" {
|
||||
clientId = "system"
|
||||
}
|
||||
return fmt.Sprintf("azure|msi|%s", clientId)
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
// No need to lock here because the caller is responsible for thread safety
|
||||
if c.credential == nil {
|
||||
var err error
|
||||
c.credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource
|
||||
if len(scopes) == 0 {
|
||||
return nil, errors.New("scopes not provided")
|
||||
}
|
||||
resource := strings.TrimSuffix(scopes[0], "/.default")
|
||||
scopes = []string{resource}
|
||||
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
type clientSecretCredential struct {
|
||||
authority string
|
||||
tenantId string
|
||||
clientId string
|
||||
clientSecret string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
// No need to lock here because the caller is responsible for thread safety
|
||||
if c.credential == nil {
|
||||
var err error
|
||||
c.credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
func hashSecret(secret string) string {
|
||||
hash := sha256.New()
|
||||
_, _ = hash.Write([]byte(secret))
|
||||
return fmt.Sprintf("%x", hash.Sum(nil))
|
||||
}
|
||||
221
pkg/api/pluginproxy/token_provider_azure_test.go
Normal file
221
pkg/api/pluginproxy/token_provider_azure_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var getAccessTokenFunc func(credential TokenCredential, scopes []string)
|
||||
|
||||
type tokenCacheFake struct{}
|
||||
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
getAccessTokenFunc(credential, scopes)
|
||||
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if auth type is client secret", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "clientsecret",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be managed identity if datasource not configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "06da9207-bdd9-4558-aee4-377450893cb4",
|
||||
"client_id": "b8c58fe8-1fca-4e30-a0a8-b44d0e5f70d6",
|
||||
"client_secret": "9bcd4434-824f-4887-a8a8-94c287bf0a7b",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when managed identities disabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if datasource not configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
|
||||
original := azureTokenCache
|
||||
azureTokenCache = &tokenCacheFake{}
|
||||
t.Cleanup(func() { azureTokenCache = original })
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should resolve managed identity credential if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.IsType(t, &managedIdentityCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("should resolve client secret credential if auth type is client secret", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "clientsecret",
|
||||
}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.IsType(t, &clientSecretCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when managed identities disabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should return error if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.Fail(t, "token cache not expected to be called")
|
||||
}
|
||||
|
||||
_, err := provider.getAccessToken()
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
ds := &models.DataSource{Id: 1, Version: 2}
|
||||
route := &plugins.AppPluginRoute{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
|
||||
"client_id": "1af7c188-e5b6-4f96-81b8-911761bdd459",
|
||||
"client_secret": "0416d95e-8af8-472c-aaa3-15c93c46080a",
|
||||
},
|
||||
}
|
||||
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
|
||||
|
||||
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
|
||||
result := provider.getClientSecretCredential()
|
||||
assert.IsType(t, &clientSecretCredential{}, result)
|
||||
|
||||
credential := (result).(*clientSecretCredential)
|
||||
|
||||
assert.Equal(t, "https://login.microsoftonline.com/", credential.authority)
|
||||
assert.Equal(t, "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", credential.tenantId)
|
||||
assert.Equal(t, "1af7c188-e5b6-4f96-81b8-911761bdd459", credential.clientId)
|
||||
assert.Equal(t, "0416d95e-8af8-472c-aaa3-15c93c46080a", credential.clientSecret)
|
||||
})
|
||||
}
|
||||
@@ -29,6 +29,7 @@ type ApplicationInsightsDatasource struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
// ApplicationInsightsQuery is the model that holds the information
|
||||
@@ -243,7 +244,7 @@ func (e *ApplicationInsightsDatasource) createRequest(ctx context.Context, dsInf
|
||||
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
|
||||
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, dsInfo)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, dsInfo, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type AzureLogAnalyticsDatasource struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
// AzureLogAnalyticsQuery is the query request that is built from the saved values for
|
||||
@@ -229,7 +230,7 @@ func (e *AzureLogAnalyticsDatasource) createRequest(ctx context.Context, dsInfo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pluginproxy.ApplyRoute(ctx, req, proxypass, logAnalyticsRoute, dsInfo)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxypass, logAnalyticsRoute, dsInfo, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type AzureMonitorDatasource struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -259,7 +260,7 @@ func (e *AzureMonitorDatasource) createRequest(ctx context.Context, dsInfo *mode
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
|
||||
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, azureMonitorRoute, dsInfo)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, azureMonitorRoute, dsInfo, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,6 +28,7 @@ func init() {
|
||||
|
||||
type Service struct {
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
}
|
||||
|
||||
func (s *Service) Init() error {
|
||||
@@ -38,6 +40,7 @@ type AzureMonitorExecutor struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
// NewAzureMonitorExecutor initializes a http client
|
||||
@@ -52,6 +55,7 @@ func (s *Service) NewExecutor(dsInfo *models.DataSource) (plugins.DataPlugin, er
|
||||
httpClient: httpClient,
|
||||
dsInfo: dsInfo,
|
||||
pluginManager: s.PluginManager,
|
||||
cfg: s.Cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -90,24 +94,28 @@ func (e *AzureMonitorExecutor) DataQuery(ctx context.Context, dsInfo *models.Dat
|
||||
httpClient: e.httpClient,
|
||||
dsInfo: e.dsInfo,
|
||||
pluginManager: e.pluginManager,
|
||||
cfg: e.cfg,
|
||||
}
|
||||
|
||||
aiDatasource := &ApplicationInsightsDatasource{
|
||||
httpClient: e.httpClient,
|
||||
dsInfo: e.dsInfo,
|
||||
pluginManager: e.pluginManager,
|
||||
cfg: e.cfg,
|
||||
}
|
||||
|
||||
alaDatasource := &AzureLogAnalyticsDatasource{
|
||||
httpClient: e.httpClient,
|
||||
dsInfo: e.dsInfo,
|
||||
pluginManager: e.pluginManager,
|
||||
cfg: e.cfg,
|
||||
}
|
||||
|
||||
iaDatasource := &InsightsAnalyticsDatasource{
|
||||
httpClient: e.httpClient,
|
||||
dsInfo: e.dsInfo,
|
||||
pluginManager: e.pluginManager,
|
||||
cfg: e.cfg,
|
||||
}
|
||||
|
||||
azResult, err := azDatasource.executeTimeSeriesQuery(ctx, azureMonitorQueries, *tsdbQuery.TimeRange)
|
||||
|
||||
@@ -25,6 +25,7 @@ type InsightsAnalyticsDatasource struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
type InsightsAnalyticsQuery struct {
|
||||
@@ -217,7 +218,7 @@ func (e *InsightsAnalyticsDatasource) createRequest(ctx context.Context, dsInfo
|
||||
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
|
||||
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, dsInfo)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, dsInfo, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -74,6 +74,7 @@ func init() {
|
||||
|
||||
type Service struct {
|
||||
PluginManager plugins.Manager `inject:""`
|
||||
Cfg *setting.Cfg `inject:""`
|
||||
}
|
||||
|
||||
func (s *Service) Init() error {
|
||||
@@ -85,6 +86,7 @@ type Executor struct {
|
||||
httpClient *http.Client
|
||||
dsInfo *models.DataSource
|
||||
pluginManager plugins.Manager
|
||||
cfg *setting.Cfg
|
||||
}
|
||||
|
||||
// NewExecutor returns an Executor.
|
||||
@@ -99,6 +101,7 @@ func (s *Service) NewExecutor(dsInfo *models.DataSource) (plugins.DataPlugin, er
|
||||
httpClient: httpClient,
|
||||
dsInfo: dsInfo,
|
||||
pluginManager: s.PluginManager,
|
||||
cfg: s.Cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -522,7 +525,7 @@ func (e *Executor) createRequest(ctx context.Context, dsInfo *models.DataSource,
|
||||
}
|
||||
}
|
||||
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, cloudMonitoringRoute, dsInfo)
|
||||
pluginproxy.ApplyRoute(ctx, req, proxyPass, cloudMonitoringRoute, dsInfo, e.cfg)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user