AzureMonitor: strongly-typed AzureCredentials and correct resolution of auth type and cloud (#36284)

This commit is contained in:
Sergey Kostrukov 2021-07-05 03:20:12 -07:00 committed by GitHub
parent 719e78f333
commit 89ba607382
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 749 additions and 404 deletions

View File

@ -92,8 +92,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
if tokenAuth == nil {
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
}
provider := newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
return provider, nil
return newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
case "gce":
if jwtTokenAuth == nil {

View File

@ -2,24 +2,58 @@ package pluginproxy
import (
"context"
"strings"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
)
type azureAccessTokenProvider struct {
ctx context.Context
tokenProvider aztokenprovider.AzureTokenProvider
scopes []string
}
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) (*azureAccessTokenProvider, error) {
credentials := getAzureCredentials(cfg, authParams)
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, credentials)
if err != nil {
return nil, err
}
return &azureAccessTokenProvider{
ctx: ctx,
tokenProvider: aztokenprovider.NewAzureAccessTokenProvider(cfg, authParams),
}
tokenProvider: tokenProvider,
scopes: authParams.Scopes,
}, nil
}
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
return provider.tokenProvider.GetAccessToken(provider.ctx)
return provider.tokenProvider.GetAccessToken(provider.ctx, provider.scopes)
}
func getAzureCredentials(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) azcredentials.AzureCredentials {
authType := strings.ToLower(authParams.Params["azure_auth_type"])
clientId := authParams.Params["client_id"]
// Type of authentication being determined by the following logic:
// * If authType is set to 'msi' then user explicitly selected the managed identity authentication
// * If authType isn't set but other fields are configured then it's a datasource which was configured
// before managed identities where introduced, therefore use client secret authentication
// * If authType and other fields aren't set then it means the datasource never been configured
// and managed identity is the default authentication choice as long as managed identities are enabled
isManagedIdentity := authType == "msi" || (authType == "" && clientId == "" && cfg.Azure.ManagedIdentityEnabled)
if isManagedIdentity {
return &azcredentials.AzureManagedIdentityCredentials{}
} else {
return &azcredentials.AzureClientSecretCredentials{
AzureCloud: authParams.Params["azure_cloud"],
Authority: authParams.Url,
TenantId: authParams.Params["tenant_id"],
ClientId: authParams.Params["client_id"],
ClientSecret: authParams.Params["client_secret"],
}
}
}

View File

@ -0,0 +1,30 @@
package azcredentials
const (
AzureAuthManagedIdentity = "msi"
AzureAuthClientSecret = "clientsecret"
)
type AzureCredentials interface {
AzureAuthType() string
}
type AzureManagedIdentityCredentials struct {
ClientId string
}
type AzureClientSecretCredentials struct {
AzureCloud string
Authority string
TenantId string
ClientId string
ClientSecret string
}
func (credentials *AzureManagedIdentityCredentials) AzureAuthType() string {
return AzureAuthManagedIdentity
}
func (credentials *AzureClientSecretCredentials) AzureAuthType() string {
return AzureAuthClientSecret
}

View File

@ -9,10 +9,10 @@ import (
const authenticationMiddlewareName = "AzureAuthentication"
func AuthMiddleware(tokenProvider AzureTokenProvider) httpclient.Middleware {
func AuthMiddleware(tokenProvider AzureTokenProvider, scopes []string) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenProvider.GetAccessToken(req.Context())
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
if err != nil {
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
}

View File

@ -19,14 +19,14 @@ type AccessToken struct {
ExpiresOn time.Time
}
type TokenCredential interface {
type TokenRetriever interface {
GetCacheKey() string
Init() error
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
}
type ConcurrentTokenCache interface {
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error)
}
func NewConcurrentTokenCache() ConcurrentTokenCache {
@ -37,7 +37,7 @@ type tokenCacheImpl struct {
cache sync.Map // of *credentialCacheEntry
}
type credentialCacheEntry struct {
credential TokenCredential
retriever TokenRetriever
credInit uint32
credMutex sync.Mutex
@ -45,19 +45,19 @@ type credentialCacheEntry struct {
}
type scopesCacheEntry struct {
credential TokenCredential
scopes []string
retriever TokenRetriever
scopes []string
cond *sync.Cond
refreshing bool
accessToken *AccessToken
}
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error) {
return c.getEntryFor(tokenRetriever).getAccessToken(ctx, scopes)
}
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
func (c *tokenCacheImpl) getEntryFor(credential TokenRetriever) *credentialCacheEntry {
var entry interface{}
var ok bool
@ -65,7 +65,7 @@ func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCach
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
credential: credential,
retriever: credential,
})
}
@ -87,8 +87,8 @@ func (c *credentialCacheEntry) ensureInitialized() error {
defer c.credMutex.Unlock()
if c.credInit == 0 {
// Initialize credential
err := c.credential.Init()
// Initialize retriever
err := c.retriever.Init()
if err != nil {
return err
}
@ -108,9 +108,9 @@ func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
credential: c.credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: c.retriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
})
}
@ -155,7 +155,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
var accessToken *AccessToken
// Safeguarding from panic caused by credential implementation
// Safeguarding from panic caused by retriever implementation
defer func() {
c.cond.L.Lock()
@ -169,7 +169,7 @@ func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken
c.cond.L.Unlock()
}()
token, err := c.credential.GetAccessToken(ctx, c.scopes)
token, err := c.retriever.GetAccessToken(ctx, c.scopes)
if err != nil {
return nil, err
}

View File

@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
)
type fakeCredential struct {
type fakeRetriever struct {
key string
initCalledTimes int
calledTimes int
@ -20,16 +20,16 @@ type fakeCredential struct {
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
func (c *fakeCredential) GetCacheKey() string {
func (c *fakeRetriever) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) Reset() {
func (c *fakeRetriever) Reset() {
c.initCalledTimes = 0
c.calledTimes = 0
}
func (c *fakeCredential) Init() error {
func (c *fakeRetriever) Init() error {
c.initCalledTimes = c.initCalledTimes + 1
if c.initFunc != nil {
return c.initFunc()
@ -37,7 +37,7 @@ func (c *fakeCredential) Init() error {
return nil
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
func (c *fakeRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
return c.getAccessTokenFunc(ctx, scopes)
@ -52,15 +52,15 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
scopes1 := []string{"Scope1"}
scopes2 := []string{"Scope2"}
t.Run("should request access token from credential", func(t *testing.T) {
t.Run("should request access token from retriever", func(t *testing.T) {
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
tokenRetriever := &fakeRetriever{key: "retriever"}
token, err := cache.GetAccessToken(ctx, credential, scopes1)
token, err := cache.GetAccessToken(ctx, tokenRetriever, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token)
assert.Equal(t, "retriever-token-1", token)
assert.Equal(t, 1, credential.calledTimes)
assert.Equal(t, 1, tokenRetriever.calledTimes)
})
t.Run("should return cached token for same scopes", func(t *testing.T) {
@ -68,7 +68,7 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
var err error
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
credential := &fakeRetriever{key: "credential-1"}
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
@ -94,8 +94,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
var err error
cache := NewConcurrentTokenCache()
credential1 := &fakeCredential{key: "credential-1"}
credential2 := &fakeCredential{key: "credential-2"}
credential1 := &fakeRetriever{key: "credential-1"}
credential2 := &fakeRetriever{key: "credential-2"}
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
@ -119,8 +119,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when credential init returns error", func(t *testing.T) {
credential := &fakeCredential{
t.Run("when retriever init returns error", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
initFunc: func() error {
return errors.New("unable to initialize")
},
@ -128,7 +128,7 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
retriever: tokenRetriever,
}
err := cacheEntry.ensureInitialized()
@ -137,10 +137,10 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
})
t.Run("should call init again each time and return error", func(t *testing.T) {
credential.Reset()
tokenRetriever.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
retriever: tokenRetriever,
}
var err error
@ -153,13 +153,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, credential.initCalledTimes)
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
})
})
t.Run("when credential init returns error only once", func(t *testing.T) {
t.Run("when retriever init returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
tokenRetriever := &fakeRetriever{
initFunc: func() error {
times = times + 1
if times == 1 {
@ -169,9 +169,9 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
},
}
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
t.Run("should call retriever init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
retriever: tokenRetriever,
}
var err error
@ -184,52 +184,52 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, credential.initCalledTimes)
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
})
})
t.Run("when credential init panics", func(t *testing.T) {
credential := &fakeCredential{
t.Run("when retriever init panics", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call credential init again each time", func(t *testing.T) {
credential.Reset()
t.Run("should call retriever init again each time", func(t *testing.T) {
tokenRetriever.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
retriever: tokenRetriever,
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, credential.initCalledTimes)
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
})
})
t.Run("when credential init panics only once", func(t *testing.T) {
t.Run("when retriever init panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
tokenRetriever := &fakeRetriever{
initFunc: func() error {
times = times + 1
if times == 1 {
@ -239,23 +239,23 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
},
}
t.Run("should call credential init again only while it panics", func(t *testing.T) {
t.Run("should call retriever init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
retriever: tokenRetriever,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
assert.Nil(t, recover(), "retriever not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
@ -263,13 +263,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
assert.Nil(t, recover(), "retriever not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, credential.initCalledTimes)
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
})
})
}
@ -279,8 +279,8 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
scopes := []string{"Scope1"}
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
credential := &fakeCredential{
t.Run("when retriever getAccessToken returns error", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
@ -289,9 +289,9 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
t.Run("should return error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: tokenRetriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
accessToken, err := cacheEntry.getAccessToken(ctx)
@ -300,13 +300,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
assert.Equal(t, "", accessToken)
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.Reset()
t.Run("should call retriever again each time and return error", func(t *testing.T) {
tokenRetriever.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: tokenRetriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var err error
@ -319,13 +319,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, 3, credential.calledTimes)
assert.Equal(t, 3, tokenRetriever.calledTimes)
})
})
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
t.Run("when retriever getAccessToken returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
retriever := &fakeRetriever{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
@ -337,11 +337,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
},
}
t.Run("should call credential again only while it returns error", func(t *testing.T) {
t.Run("should call retriever again only while it returns error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: retriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
@ -358,54 +358,54 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
assert.Equal(t, 2, credential.calledTimes)
assert.Equal(t, 2, retriever.calledTimes)
})
})
t.Run("when credential getAccessToken panics", func(t *testing.T) {
credential := &fakeCredential{
t.Run("when retriever getAccessToken panics", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
},
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.Reset()
t.Run("should call retriever again each time", func(t *testing.T) {
tokenRetriever.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: tokenRetriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
assert.Equal(t, 3, credential.calledTimes)
assert.Equal(t, 3, tokenRetriever.calledTimes)
})
})
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
t.Run("when retriever getAccessToken panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
tokenRetriever := &fakeRetriever{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
@ -416,11 +416,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
},
}
t.Run("should call credential again only while it panics", func(t *testing.T) {
t.Run("should call retriever again only while it panics", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
retriever: tokenRetriever,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
@ -428,14 +428,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
assert.Nil(t, recover(), "retriever not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
@ -444,14 +444,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
assert.Nil(t, recover(), "retriever not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
assert.Equal(t, 2, credential.calledTimes)
assert.Equal(t, 2, tokenRetriever.calledTimes)
})
})
}

View File

@ -4,12 +4,11 @@ import (
"context"
"crypto/sha256"
"fmt"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
)
var (
@ -17,72 +16,92 @@ var (
)
type AzureTokenProvider interface {
GetAccessToken(ctx context.Context) (string, error)
GetAccessToken(ctx context.Context, scopes []string) (string, error)
}
type tokenProviderImpl struct {
cfg *setting.Cfg
authParams *plugins.JwtTokenAuth
tokenRetriever TokenRetriever
}
func NewAzureAccessTokenProvider(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *tokenProviderImpl {
return &tokenProviderImpl{
cfg: cfg,
authParams: authParams,
func NewAzureAccessTokenProvider(cfg *setting.Cfg, credentials azcredentials.AzureCredentials) (AzureTokenProvider, error) {
if cfg == nil {
err := fmt.Errorf("parameter 'cfg' cannot be nil")
return nil, err
}
if credentials == nil {
err := fmt.Errorf("parameter 'credentials' cannot be nil")
return nil, err
}
}
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context) (string, error) {
var credential TokenCredential
var tokenRetriever TokenRetriever
if provider.isManagedIdentityCredential() {
if !provider.cfg.Azure.ManagedIdentityEnabled {
switch c := credentials.(type) {
case *azcredentials.AzureManagedIdentityCredentials:
if !cfg.Azure.ManagedIdentityEnabled {
err := fmt.Errorf("managed identity authentication is not enabled in Grafana config")
return "", err
return nil, err
} else {
credential = provider.getManagedIdentityCredential()
tokenRetriever = getManagedIdentityTokenRetriever(cfg, c)
}
} else {
credential = provider.getClientSecretCredential()
case *azcredentials.AzureClientSecretCredentials:
tokenRetriever = getClientSecretTokenRetriever(c)
default:
err := fmt.Errorf("credentials of type '%s' not supported by authentication provider", c.AzureAuthType())
return nil, err
}
accessToken, err := azureTokenCache.GetAccessToken(ctx, credential, provider.authParams.Scopes)
if err != nil {
tokenProvider := &tokenProviderImpl{
tokenRetriever: tokenRetriever,
}
return tokenProvider, nil
}
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context, scopes []string) (string, error) {
if ctx == nil {
err := fmt.Errorf("parameter 'ctx' cannot be nil")
return "", err
}
if scopes == nil {
err := fmt.Errorf("parameter 'scopes' cannot be nil")
return "", err
}
accessToken, err := azureTokenCache.GetAccessToken(ctx, provider.tokenRetriever, scopes)
if err != nil {
return "", err
}
return accessToken, nil
}
func (provider *tokenProviderImpl) isManagedIdentityCredential() bool {
authType := strings.ToLower(provider.authParams.Params["azure_auth_type"])
clientId := provider.authParams.Params["client_id"]
// Type of authentication being determined by the following logic:
// * If authType is set to 'msi' then user explicitly selected the managed identity authentication
// * If authType isn't set but other fields are configured then it's a datasource which was configured
// before managed identities where introduced, therefore use client secret authentication
// * If authType and other fields aren't set then it means the datasource never been configured
// and managed identity is the default authentication choice as long as managed identities are enabled
return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled)
func getManagedIdentityTokenRetriever(cfg *setting.Cfg, credentials *azcredentials.AzureManagedIdentityCredentials) TokenRetriever {
var clientId string
if credentials.ClientId != "" {
clientId = credentials.ClientId
} else {
clientId = cfg.Azure.ManagedIdentityClientId
}
return &managedIdentityTokenRetriever{
clientId: clientId,
}
}
func (provider *tokenProviderImpl) getManagedIdentityCredential() TokenCredential {
clientId := provider.cfg.Azure.ManagedIdentityClientId
return &managedIdentityCredential{clientId: clientId}
func getClientSecretTokenRetriever(credentials *azcredentials.AzureClientSecretCredentials) TokenRetriever {
var authority string
if credentials.Authority != "" {
authority = credentials.Authority
} else {
authority = resolveAuthorityForCloud(credentials.AzureCloud)
}
return &clientSecretTokenRetriever{
authority: authority,
tenantId: credentials.TenantId,
clientId: credentials.ClientId,
clientSecret: credentials.ClientSecret,
}
}
func (provider *tokenProviderImpl) getClientSecretCredential() TokenCredential {
authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"])
tenantId := provider.authParams.Params["tenant_id"]
clientId := provider.authParams.Params["client_id"]
clientSecret := provider.authParams.Params["client_secret"]
return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret}
}
func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string {
func resolveAuthorityForCloud(cloudName string) string {
// Known Azure clouds
switch cloudName {
case setting.AzurePublic:
@ -93,17 +112,17 @@ func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string
return azidentity.AzureGovernment
case setting.AzureGermany:
return azidentity.AzureGermany
default:
return ""
}
// Fallback to direct URL
return provider.authParams.Url
}
type managedIdentityCredential struct {
type managedIdentityTokenRetriever struct {
clientId string
credential azcore.TokenCredential
}
func (c *managedIdentityCredential) GetCacheKey() string {
func (c *managedIdentityTokenRetriever) GetCacheKey() string {
clientId := c.clientId
if clientId == "" {
clientId = "system"
@ -111,7 +130,7 @@ func (c *managedIdentityCredential) GetCacheKey() string {
return fmt.Sprintf("azure|msi|%s", clientId)
}
func (c *managedIdentityCredential) Init() error {
func (c *managedIdentityTokenRetriever) Init() error {
if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil {
return err
} else {
@ -120,7 +139,7 @@ func (c *managedIdentityCredential) Init() error {
}
}
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
func (c *managedIdentityTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err
@ -129,7 +148,7 @@ func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes [
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
}
type clientSecretCredential struct {
type clientSecretTokenRetriever struct {
authority string
tenantId string
clientId string
@ -137,11 +156,11 @@ type clientSecretCredential struct {
credential azcore.TokenCredential
}
func (c *clientSecretCredential) GetCacheKey() string {
func (c *clientSecretTokenRetriever) GetCacheKey() string {
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
}
func (c *clientSecretCredential) Init() error {
func (c *clientSecretTokenRetriever) Init() error {
options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority}
if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil {
return err
@ -151,7 +170,7 @@ func (c *clientSecretCredential) Init() error {
}
}
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
func (c *clientSecretTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err

View File

@ -4,125 +4,30 @@ import (
"context"
"testing"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var getAccessTokenFunc func(credential TokenCredential, scopes []string)
var getAccessTokenFunc func(credential TokenRetriever, scopes []string)
type tokenCacheFake struct{}
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenRetriever, scopes []string) (string, error) {
getAccessTokenFunc(credential, scopes)
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
}
func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
cfg := &setting.Cfg{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
},
Params: map[string]string{
"azure_auth_type": "",
"azure_cloud": "AzureCloud",
"tenant_id": "",
"client_id": "",
"client_secret": "",
},
}
provider := NewAzureAccessTokenProvider(cfg, authParams)
t.Run("when managed identities enabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = true
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "msi",
}
assert.True(t, provider.isManagedIdentityCredential())
})
t.Run("should be client secret if auth type is client secret", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "clientsecret",
}
assert.False(t, provider.isManagedIdentityCredential())
})
t.Run("should be managed identity if datasource not configured", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "",
"tenant_id": "",
"client_id": "",
"client_secret": "",
}
assert.True(t, provider.isManagedIdentityCredential())
})
t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "",
"tenant_id": "06da9207-bdd9-4558-aee4-377450893cb4",
"client_id": "b8c58fe8-1fca-4e30-a0a8-b44d0e5f70d6",
"client_secret": "9bcd4434-824f-4887-a8a8-94c287bf0a7b",
}
assert.False(t, provider.isManagedIdentityCredential())
})
})
t.Run("when managed identities disabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = false
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "msi",
}
assert.True(t, provider.isManagedIdentityCredential())
})
t.Run("should be client secret if datasource not configured", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "",
"tenant_id": "",
"client_id": "",
"client_secret": "",
}
assert.False(t, provider.isManagedIdentityCredential())
})
})
}
func TestAzureTokenProvider_getAccessToken(t *testing.T) {
func TestAzureTokenProvider_GetAccessToken(t *testing.T) {
ctx := context.Background()
cfg := &setting.Cfg{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
},
Params: map[string]string{
"azure_auth_type": "",
"azure_cloud": "AzureCloud",
"tenant_id": "",
"client_id": "",
"client_secret": "",
},
scopes := []string{
"https://management.azure.com/.default",
}
provider := NewAzureAccessTokenProvider(cfg, authParams)
original := azureTokenCache
azureTokenCache = &tokenCacheFake{}
t.Cleanup(func() { azureTokenCache = original })
@ -130,29 +35,31 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
t.Run("when managed identities enabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = true
t.Run("should resolve managed identity credential if auth type is managed identity", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "msi",
t.Run("should resolve managed identity retriever if auth type is managed identity", func(t *testing.T) {
credentials := &azcredentials.AzureManagedIdentityCredentials{}
provider, err := NewAzureAccessTokenProvider(cfg, credentials)
require.NoError(t, err)
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
assert.IsType(t, &managedIdentityTokenRetriever{}, credential)
}
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
assert.IsType(t, &managedIdentityCredential{}, credential)
}
_, err := provider.GetAccessToken(ctx)
_, err = provider.GetAccessToken(ctx, scopes)
require.NoError(t, err)
})
t.Run("should resolve client secret credential if auth type is client secret", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "clientsecret",
t.Run("should resolve client secret retriever if auth type is client secret", func(t *testing.T) {
credentials := &azcredentials.AzureClientSecretCredentials{}
provider, err := NewAzureAccessTokenProvider(cfg, credentials)
require.NoError(t, err)
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
assert.IsType(t, &clientSecretTokenRetriever{}, credential)
}
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
assert.IsType(t, &clientSecretCredential{}, credential)
}
_, err := provider.GetAccessToken(ctx)
_, err = provider.GetAccessToken(ctx, scopes)
require.NoError(t, err)
})
})
@ -161,47 +68,61 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = false
t.Run("should return error if auth type is managed identity", func(t *testing.T) {
authParams.Params = map[string]string{
"azure_auth_type": "msi",
}
credentials := &azcredentials.AzureManagedIdentityCredentials{}
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
assert.Fail(t, "token cache not expected to be called")
}
_, err := provider.GetAccessToken(ctx)
require.Error(t, err)
_, err := NewAzureAccessTokenProvider(cfg, credentials)
assert.Error(t, err, "managed identity authentication is not enabled in Grafana config")
})
})
}
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
cfg := &setting.Cfg{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
},
Params: map[string]string{
"azure_auth_type": "",
"azure_cloud": "AzureCloud",
"tenant_id": "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
"client_id": "1af7c188-e5b6-4f96-81b8-911761bdd459",
"client_secret": "0416d95e-8af8-472c-aaa3-15c93c46080a",
},
credentials := &azcredentials.AzureClientSecretCredentials{
AzureCloud: setting.AzurePublic,
Authority: "",
TenantId: "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
ClientId: "1af7c188-e5b6-4f96-81b8-911761bdd459",
ClientSecret: "0416d95e-8af8-472c-aaa3-15c93c46080a",
}
provider := NewAzureAccessTokenProvider(cfg, authParams)
t.Run("should return clientSecretTokenRetriever with values", func(t *testing.T) {
result := getClientSecretTokenRetriever(credentials)
assert.IsType(t, &clientSecretTokenRetriever{}, result)
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
result := provider.getClientSecretCredential()
assert.IsType(t, &clientSecretCredential{}, result)
credential := (result).(*clientSecretCredential)
credential := (result).(*clientSecretTokenRetriever)
assert.Equal(t, "https://login.microsoftonline.com/", credential.authority)
assert.Equal(t, "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", credential.tenantId)
assert.Equal(t, "1af7c188-e5b6-4f96-81b8-911761bdd459", credential.clientId)
assert.Equal(t, "0416d95e-8af8-472c-aaa3-15c93c46080a", credential.clientSecret)
})
t.Run("authority should selected based on cloud", func(t *testing.T) {
originalCloud := credentials.AzureCloud
defer func() { credentials.AzureCloud = originalCloud }()
credentials.AzureCloud = setting.AzureChina
result := getClientSecretTokenRetriever(credentials)
assert.IsType(t, &clientSecretTokenRetriever{}, result)
credential := (result).(*clientSecretTokenRetriever)
assert.Equal(t, "https://login.chinacloudapi.cn/", credential.authority)
})
t.Run("explicitly set authority should have priority over cloud", func(t *testing.T) {
originalCloud := credentials.AzureCloud
defer func() { credentials.AzureCloud = originalCloud }()
credentials.AzureCloud = setting.AzureChina
credentials.Authority = "https://another.com/"
result := getClientSecretTokenRetriever(credentials)
assert.IsType(t, &clientSecretTokenRetriever{}, result)
credential := (result).(*clientSecretTokenRetriever)
assert.Equal(t, "https://another.com/", credential.authority)
})
}

View File

@ -11,12 +11,14 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
)
const (
@ -44,20 +46,15 @@ type Service struct {
}
type azureMonitorSettings struct {
SubscriptionId string `json:"subscriptionId"`
LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"`
AppInsightsAppId string `json:"appInsightsAppId"`
AzureLogAnalyticsSameAs bool `json:"azureLogAnalyticsSameAs"`
ClientId string `json:"clientId"`
CloudName string `json:"cloudName"`
LogAnalyticsClientId string `json:"logAnalyticsClientId"`
LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"`
LogAnalyticsSubscriptionId string `json:"logAnalyticsSubscriptionId"`
LogAnalyticsTenantId string `json:"logAnalyticsTenantId"`
SubscriptionId string `json:"subscriptionId"`
TenantId string `json:"tenantId"`
AzureAuthType string `json:"azureAuthType,omitempty"`
}
type datasourceInfo struct {
Cloud string
Credentials azcredentials.AzureCredentials
Settings azureMonitorSettings
Services map[string]datasourceService
Routes map[string]azRoute
@ -74,10 +71,15 @@ type datasourceService struct {
HTTPClient *http.Client
}
func NewInstanceSettings() datasource.InstanceFactoryFunc {
func NewInstanceSettings(cfg *setting.Cfg) datasource.InstanceFactoryFunc {
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
jsonData := map[string]interface{}{}
err := json.Unmarshal(settings.JSONData, &jsonData)
jsonData, err := simplejson.NewJson(settings.JSONData)
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
jsonDataObj := map[string]interface{}{}
err = json.Unmarshal(settings.JSONData, &jsonDataObj)
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
@ -87,20 +89,34 @@ func NewInstanceSettings() datasource.InstanceFactoryFunc {
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
cloud, err := getAzureCloud(cfg, jsonData)
if err != nil {
return nil, fmt.Errorf("error getting credentials: %w", err)
}
credentials, err := getAzureCredentials(cfg, jsonData, settings.DecryptedSecureJSONData)
if err != nil {
return nil, fmt.Errorf("error getting credentials: %w", err)
}
httpCliOpts, err := settings.HTTPClientOptions()
if err != nil {
return nil, fmt.Errorf("error getting http options: %w", err)
}
model := datasourceInfo{
Cloud: cloud,
Credentials: credentials,
Settings: azMonitorSettings,
JSONData: jsonData,
JSONData: jsonDataObj,
DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
DatasourceID: settings.ID,
Services: map[string]datasourceService{},
Routes: routes[azMonitorSettings.CloudName],
Routes: routes[cloud],
HTTPCliOpts: httpCliOpts,
}
return model, nil
}
}
@ -141,7 +157,7 @@ func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors ma
}
func (s *Service) Init() error {
im := datasource.NewInstanceManager(NewInstanceSettings())
im := datasource.NewInstanceManager(NewInstanceSettings(s.Cfg))
executors := map[string]azDatasourceExecutor{
azureMonitor: &AzureMonitorDatasource{},
appInsights: &ApplicationInsightsDatasource{},

View File

@ -9,6 +9,7 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/stretchr/testify/require"
)
@ -22,14 +23,16 @@ func TestNewInstanceSettings(t *testing.T) {
{
name: "creates an instance",
settings: backend.DataSourceInstanceSettings{
JSONData: []byte(`{"cloudName":"azuremonitor"}`),
JSONData: []byte(`{"azureAuthType":"msi"}`),
DecryptedSecureJSONData: map[string]string{"key": "value"},
ID: 40,
},
expectedModel: datasourceInfo{
Settings: azureMonitorSettings{CloudName: "azuremonitor"},
Routes: routes["azuremonitor"],
JSONData: map[string]interface{}{"cloudName": string("azuremonitor")},
Cloud: setting.AzurePublic,
Credentials: &azcredentials.AzureManagedIdentityCredentials{},
Settings: azureMonitorSettings{},
Routes: routes[setting.AzurePublic],
JSONData: map[string]interface{}{"azureAuthType": "msi"},
DatasourceID: 40,
DecryptedSecureJSONData: map[string]string{"key": "value"},
},
@ -37,9 +40,15 @@ func TestNewInstanceSettings(t *testing.T) {
},
}
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzurePublic,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
factory := NewInstanceSettings()
factory := NewInstanceSettings(cfg)
instance, err := factory(tt.settings)
tt.Err(t, err)
if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) {

View File

@ -1,12 +1,11 @@
package azuremonitor
import (
"net/http"
"fmt"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
)
// Azure cloud names specific to Azure Monitor
@ -17,40 +16,107 @@ const (
azureMonitorGermany = "germanyazuremonitor"
)
// Azure cloud query types
const (
azureMonitor = "Azure Monitor"
appInsights = "Application Insights"
azureLogAnalytics = "Azure Log Analytics"
insightsAnalytics = "Insights Analytics"
azureResourceGraph = "Azure Resource Graph"
)
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
if len(route.Scopes) > 0 {
tokenAuth := &plugins.JwtTokenAuth{
Url: route.URL,
Scopes: route.Scopes,
Params: map[string]string{
"azure_auth_type": model.Settings.AzureAuthType,
"azure_cloud": cfg.Azure.Cloud,
"tenant_id": model.Settings.TenantId,
"client_id": model.Settings.ClientId,
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
},
}
tokenProvider := aztokenprovider.NewAzureAccessTokenProvider(cfg, tokenAuth)
return httpclient.NewProvider(httpclient.ProviderOptions{
Middlewares: []httpclient.Middleware{
aztokenprovider.AuthMiddleware(tokenProvider),
},
})
func getAuthType(cfg *setting.Cfg, jsonData *simplejson.Json) string {
if azureAuthType := jsonData.Get("azureAuthType").MustString(); azureAuthType != "" {
return azureAuthType
} else {
return httpclient.NewProvider()
tenantId := jsonData.Get("tenantId").MustString()
clientId := jsonData.Get("clientId").MustString()
// If authentication type isn't explicitly specified and datasource has client credentials,
// then this is existing datasource which is configured for app registration (client secret)
if tenantId != "" && clientId != "" {
return azcredentials.AzureAuthClientSecret
}
// For newly created datasource with no configuration, managed identity is the default authentication type
// if they are enabled in Grafana config
if cfg.Azure.ManagedIdentityEnabled {
return azcredentials.AzureAuthManagedIdentity
} else {
return azcredentials.AzureAuthClientSecret
}
}
}
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
model.HTTPCliOpts.Headers = route.Headers
return httpClientProvider(route, model, cfg).New(model.HTTPCliOpts)
func getDefaultAzureCloud(cfg *setting.Cfg) (string, error) {
// Allow only known cloud names
cloudName := cfg.Azure.Cloud
switch cloudName {
case setting.AzurePublic:
return setting.AzurePublic, nil
case setting.AzureChina:
return setting.AzureChina, nil
case setting.AzureUSGovernment:
return setting.AzureUSGovernment, nil
case setting.AzureGermany:
return setting.AzureGermany, nil
case "":
// Not set cloud defaults to public
return setting.AzurePublic, nil
default:
err := fmt.Errorf("the cloud '%s' not supported", cloudName)
return "", err
}
}
func normalizeAzureCloud(cloudName string) (string, error) {
switch cloudName {
case azureMonitorPublic:
return setting.AzurePublic, nil
case azureMonitorChina:
return setting.AzureChina, nil
case azureMonitorUSGovernment:
return setting.AzureUSGovernment, nil
case azureMonitorGermany:
return setting.AzureGermany, nil
default:
err := fmt.Errorf("the cloud '%s' not supported", cloudName)
return "", err
}
}
func getAzureCloud(cfg *setting.Cfg, jsonData *simplejson.Json) (string, error) {
authType := getAuthType(cfg, jsonData)
switch authType {
case azcredentials.AzureAuthManagedIdentity:
// In case of managed identity, the cloud is always same as where Grafana is hosted
return getDefaultAzureCloud(cfg)
case azcredentials.AzureAuthClientSecret:
if cloud := jsonData.Get("cloudName").MustString(); cloud != "" {
return normalizeAzureCloud(cloud)
} else {
return getDefaultAzureCloud(cfg)
}
default:
err := fmt.Errorf("the authentication type '%s' not supported", authType)
return "", err
}
}
func getAzureCredentials(cfg *setting.Cfg, jsonData *simplejson.Json, secureJsonData map[string]string) (azcredentials.AzureCredentials, error) {
authType := getAuthType(cfg, jsonData)
switch authType {
case azcredentials.AzureAuthManagedIdentity:
credentials := &azcredentials.AzureManagedIdentityCredentials{}
return credentials, nil
case azcredentials.AzureAuthClientSecret:
cloud, err := getAzureCloud(cfg, jsonData)
if err != nil {
return nil, err
}
credentials := &azcredentials.AzureClientSecretCredentials{
AzureCloud: cloud,
TenantId: jsonData.Get("tenantId").MustString(),
ClientId: jsonData.Get("clientId").MustString(),
ClientSecret: secureJsonData["clientSecret"],
}
return credentials, nil
default:
err := fmt.Errorf("the authentication type '%s' not supported", authType)
return nil, err
}
}

View File

@ -3,50 +3,195 @@ package azuremonitor
import (
"testing"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_httpCliProvider(t *testing.T) {
func TestCredentials_getAuthType(t *testing.T) {
cfg := &setting.Cfg{}
model := datasourceInfo{
Settings: azureMonitorSettings{},
DecryptedSecureJSONData: map[string]string{"clientSecret": "content"},
}
tests := []struct {
name string
route azRoute
expectedMiddlewares int
Err require.ErrorAssertionFunc
}{
{
name: "creates an HTTP client with a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{"http://route/.default"},
},
expectedMiddlewares: 1,
Err: require.NoError,
},
{
name: "creates an HTTP client without a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{},
},
// httpclient.NewProvider returns a client with 2 middlewares by default
expectedMiddlewares: 2,
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli := httpClientProvider(tt.route, model, cfg)
// Cannot test that the cli middleware works properly since the azcore sdk
// rejects the TLS certs (if provided)
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
}
t.Run("when managed identities enabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = true
t.Run("should be client secret if auth type is set to client secret", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthClientSecret,
})
authType := getAuthType(cfg, jsonData)
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
})
}
t.Run("should be managed identity if datasource not configured", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": "",
})
authType := getAuthType(cfg, jsonData)
assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType)
})
t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": "",
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
})
authType := getAuthType(cfg, jsonData)
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
})
})
t.Run("when managed identities disabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = false
t.Run("should be managed identity if auth type is set to managed identity", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
})
authType := getAuthType(cfg, jsonData)
assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType)
})
t.Run("should be client secret if datasource not configured", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": "",
})
authType := getAuthType(cfg, jsonData)
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
})
})
}
func TestCredentials_getAzureCloud(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzureChina,
},
}
t.Run("when auth type is managed identity", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
"cloudName": azureMonitorGermany,
})
t.Run("should be from server configuration regardless of datasource value", func(t *testing.T) {
cloud, err := getAzureCloud(cfg, jsonData)
require.NoError(t, err)
assert.Equal(t, setting.AzureChina, cloud)
})
t.Run("should be public if not set in server configuration", func(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: "",
},
}
cloud, err := getAzureCloud(cfg, jsonData)
require.NoError(t, err)
assert.Equal(t, setting.AzurePublic, cloud)
})
})
t.Run("when auth type is client secret", func(t *testing.T) {
t.Run("should be from datasource value normalized to known cloud name", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthClientSecret,
"cloudName": azureMonitorGermany,
})
cloud, err := getAzureCloud(cfg, jsonData)
require.NoError(t, err)
assert.Equal(t, setting.AzureGermany, cloud)
})
t.Run("should be from server configuration if not set in datasource", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthClientSecret,
"cloudName": "",
})
cloud, err := getAzureCloud(cfg, jsonData)
require.NoError(t, err)
assert.Equal(t, setting.AzureChina, cloud)
})
})
}
func TestCredentials_getAzureCredentials(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzureChina,
},
}
secureJsonData := map[string]string{
"clientSecret": "59e3498f-eb12-4943-b8f0-a5aa42640058",
}
t.Run("when auth type is managed identity", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
"cloudName": azureMonitorGermany,
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
})
t.Run("should return managed identity credentials", func(t *testing.T) {
credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData)
require.NoError(t, err)
require.IsType(t, &azcredentials.AzureManagedIdentityCredentials{}, credentials)
msiCredentials := credentials.(*azcredentials.AzureManagedIdentityCredentials)
// Azure Monitor datasource doesn't support user-assigned managed identities (ClientId is always empty)
assert.Equal(t, "", msiCredentials.ClientId)
})
})
t.Run("when auth type is client secret", func(t *testing.T) {
jsonData := simplejson.NewFromAny(map[string]interface{}{
"azureAuthType": azcredentials.AzureAuthClientSecret,
"cloudName": azureMonitorGermany,
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
})
t.Run("should return client secret credentials", func(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzureChina,
},
}
credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData)
require.NoError(t, err)
require.IsType(t, &azcredentials.AzureClientSecretCredentials{}, credentials)
clientSecretCredentials := credentials.(*azcredentials.AzureClientSecretCredentials)
assert.Equal(t, setting.AzureGermany, clientSecretCredentials.AzureCloud)
assert.Equal(t, "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", clientSecretCredentials.TenantId)
assert.Equal(t, "849ccbb0-92eb-4226-b228-ef391abd8fe6", clientSecretCredentials.ClientId)
assert.Equal(t, "59e3498f-eb12-4943-b8f0-a5aa42640058", clientSecretCredentials.ClientSecret)
// Azure Monitor datasource doesn't support custom IdP authorities (Authority is always empty)
assert.Equal(t, "", clientSecretCredentials.Authority)
})
})
}

View File

@ -0,0 +1,41 @@
package azuremonitor
import (
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
)
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*httpclient.Provider, error) {
var clientProvider *httpclient.Provider
if len(route.Scopes) > 0 {
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, model.Credentials)
if err != nil {
return nil, err
}
clientProvider = httpclient.NewProvider(httpclient.ProviderOptions{
Middlewares: []httpclient.Middleware{
aztokenprovider.AuthMiddleware(tokenProvider, route.Scopes),
},
})
} else {
clientProvider = httpclient.NewProvider()
}
return clientProvider, nil
}
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
model.HTTPCliOpts.Headers = route.Headers
clientProvider, err := httpClientProvider(route, model, cfg)
if err != nil {
return nil, err
}
return clientProvider.New(model.HTTPCliOpts)
}

View File

@ -0,0 +1,54 @@
package azuremonitor
import (
"testing"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/stretchr/testify/require"
)
func Test_httpCliProvider(t *testing.T) {
cfg := &setting.Cfg{}
model := datasourceInfo{
Credentials: &azcredentials.AzureClientSecretCredentials{},
}
tests := []struct {
name string
route azRoute
expectedMiddlewares int
Err require.ErrorAssertionFunc
}{
{
name: "creates an HTTP client with a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{"http://route/.default"},
},
expectedMiddlewares: 1,
Err: require.NoError,
},
{
name: "creates an HTTP client without a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{},
},
// httpclient.NewProvider returns a client with 2 middlewares by default
expectedMiddlewares: 2,
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli, err := httpClientProvider(tt.route, model, cfg)
require.NoError(t, err)
// Cannot test that the cli middleware works properly since the azcore sdk
// rejects the TLS certs (if provided)
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
}
})
}
}

View File

@ -1,5 +1,16 @@
package azuremonitor
import "github.com/grafana/grafana/pkg/setting"
// Azure cloud query types
const (
azureMonitor = "Azure Monitor"
appInsights = "Application Insights"
azureLogAnalytics = "Azure Log Analytics"
insightsAnalytics = "Insights Analytics"
azureResourceGraph = "Azure Resource Graph"
)
type azRoute struct {
URL string
Scopes []string
@ -64,22 +75,22 @@ var (
// The different Azure routes are identified by its cloud (e.g. public or gov)
// and the service to query (e.g. Azure Monitor or Azure Log Analytics)
routes = map[string]map[string]azRoute{
azureMonitorPublic: {
setting.AzurePublic: {
azureMonitor: azManagement,
azureLogAnalytics: azLogAnalytics,
azureResourceGraph: azManagement,
appInsights: azAppInsights,
insightsAnalytics: azAppInsights,
},
azureMonitorUSGovernment: {
setting.AzureUSGovernment: {
azureMonitor: azUSGovManagement,
azureLogAnalytics: azUSGovLogAnalytics,
azureResourceGraph: azUSGovManagement,
},
azureMonitorGermany: {
setting.AzureGermany: {
azureMonitor: azGermanyManagement,
},
azureMonitorChina: {
setting.AzureChina: {
azureMonitor: azChinaManagement,
azureLogAnalytics: azChinaLogAnalytics,
azureResourceGraph: azChinaManagement,