mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AzureMonitor: strongly-typed AzureCredentials and correct resolution of auth type and cloud (#36284)
This commit is contained in:
parent
719e78f333
commit
89ba607382
@ -92,8 +92,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
|
||||
if tokenAuth == nil {
|
||||
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
|
||||
}
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
return provider, nil
|
||||
return newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
|
||||
case "gce":
|
||||
if jwtTokenAuth == nil {
|
||||
|
@ -2,24 +2,58 @@ package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
ctx context.Context
|
||||
tokenProvider aztokenprovider.AzureTokenProvider
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
|
||||
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) (*azureAccessTokenProvider, error) {
|
||||
credentials := getAzureCredentials(cfg, authParams)
|
||||
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &azureAccessTokenProvider{
|
||||
ctx: ctx,
|
||||
tokenProvider: aztokenprovider.NewAzureAccessTokenProvider(cfg, authParams),
|
||||
}
|
||||
tokenProvider: tokenProvider,
|
||||
scopes: authParams.Scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
return provider.tokenProvider.GetAccessToken(provider.ctx)
|
||||
return provider.tokenProvider.GetAccessToken(provider.ctx, provider.scopes)
|
||||
}
|
||||
|
||||
func getAzureCredentials(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) azcredentials.AzureCredentials {
|
||||
authType := strings.ToLower(authParams.Params["azure_auth_type"])
|
||||
clientId := authParams.Params["client_id"]
|
||||
|
||||
// Type of authentication being determined by the following logic:
|
||||
// * If authType is set to 'msi' then user explicitly selected the managed identity authentication
|
||||
// * If authType isn't set but other fields are configured then it's a datasource which was configured
|
||||
// before managed identities where introduced, therefore use client secret authentication
|
||||
// * If authType and other fields aren't set then it means the datasource never been configured
|
||||
// and managed identity is the default authentication choice as long as managed identities are enabled
|
||||
isManagedIdentity := authType == "msi" || (authType == "" && clientId == "" && cfg.Azure.ManagedIdentityEnabled)
|
||||
|
||||
if isManagedIdentity {
|
||||
return &azcredentials.AzureManagedIdentityCredentials{}
|
||||
} else {
|
||||
return &azcredentials.AzureClientSecretCredentials{
|
||||
AzureCloud: authParams.Params["azure_cloud"],
|
||||
Authority: authParams.Url,
|
||||
TenantId: authParams.Params["tenant_id"],
|
||||
ClientId: authParams.Params["client_id"],
|
||||
ClientSecret: authParams.Params["client_secret"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
30
pkg/tsdb/azuremonitor/azcredentials/credentials.go
Normal file
30
pkg/tsdb/azuremonitor/azcredentials/credentials.go
Normal file
@ -0,0 +1,30 @@
|
||||
package azcredentials
|
||||
|
||||
const (
|
||||
AzureAuthManagedIdentity = "msi"
|
||||
AzureAuthClientSecret = "clientsecret"
|
||||
)
|
||||
|
||||
type AzureCredentials interface {
|
||||
AzureAuthType() string
|
||||
}
|
||||
|
||||
type AzureManagedIdentityCredentials struct {
|
||||
ClientId string
|
||||
}
|
||||
|
||||
type AzureClientSecretCredentials struct {
|
||||
AzureCloud string
|
||||
Authority string
|
||||
TenantId string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
func (credentials *AzureManagedIdentityCredentials) AzureAuthType() string {
|
||||
return AzureAuthManagedIdentity
|
||||
}
|
||||
|
||||
func (credentials *AzureClientSecretCredentials) AzureAuthType() string {
|
||||
return AzureAuthClientSecret
|
||||
}
|
@ -9,10 +9,10 @@ import (
|
||||
|
||||
const authenticationMiddlewareName = "AzureAuthentication"
|
||||
|
||||
func AuthMiddleware(tokenProvider AzureTokenProvider) httpclient.Middleware {
|
||||
func AuthMiddleware(tokenProvider AzureTokenProvider, scopes []string) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken(req.Context())
|
||||
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
|
||||
}
|
||||
|
@ -19,14 +19,14 @@ type AccessToken struct {
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenCredential interface {
|
||||
type TokenRetriever interface {
|
||||
GetCacheKey() string
|
||||
Init() error
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
type ConcurrentTokenCache interface {
|
||||
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
|
||||
GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache {
|
||||
@ -37,7 +37,7 @@ type tokenCacheImpl struct {
|
||||
cache sync.Map // of *credentialCacheEntry
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
credential TokenCredential
|
||||
retriever TokenRetriever
|
||||
|
||||
credInit uint32
|
||||
credMutex sync.Mutex
|
||||
@ -45,19 +45,19 @@ type credentialCacheEntry struct {
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
credential TokenCredential
|
||||
scopes []string
|
||||
retriever TokenRetriever
|
||||
scopes []string
|
||||
|
||||
cond *sync.Cond
|
||||
refreshing bool
|
||||
accessToken *AccessToken
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error) {
|
||||
return c.getEntryFor(tokenRetriever).getAccessToken(ctx, scopes)
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenRetriever) *credentialCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
@ -65,7 +65,7 @@ func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCach
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: credential,
|
||||
})
|
||||
}
|
||||
|
||||
@ -87,8 +87,8 @@ func (c *credentialCacheEntry) ensureInitialized() error {
|
||||
defer c.credMutex.Unlock()
|
||||
|
||||
if c.credInit == 0 {
|
||||
// Initialize credential
|
||||
err := c.credential.Init()
|
||||
// Initialize retriever
|
||||
err := c.retriever.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -108,9 +108,9 @@ func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
|
||||
credential: c.credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: c.retriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
@ -155,7 +155,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
|
||||
var accessToken *AccessToken
|
||||
|
||||
// Safeguarding from panic caused by credential implementation
|
||||
// Safeguarding from panic caused by retriever implementation
|
||||
defer func() {
|
||||
c.cond.L.Lock()
|
||||
|
||||
@ -169,7 +169,7 @@ func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken
|
||||
c.cond.L.Unlock()
|
||||
}()
|
||||
|
||||
token, err := c.credential.GetAccessToken(ctx, c.scopes)
|
||||
token, err := c.retriever.GetAccessToken(ctx, c.scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeCredential struct {
|
||||
type fakeRetriever struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
@ -20,16 +20,16 @@ type fakeCredential struct {
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetCacheKey() string {
|
||||
func (c *fakeRetriever) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Reset() {
|
||||
func (c *fakeRetriever) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Init() error {
|
||||
func (c *fakeRetriever) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
@ -37,7 +37,7 @@ func (c *fakeCredential) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
func (c *fakeRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
return c.getAccessTokenFunc(ctx, scopes)
|
||||
@ -52,15 +52,15 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
scopes1 := []string{"Scope1"}
|
||||
scopes2 := []string{"Scope2"}
|
||||
|
||||
t.Run("should request access token from credential", func(t *testing.T) {
|
||||
t.Run("should request access token from retriever", func(t *testing.T) {
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
tokenRetriever := &fakeRetriever{key: "retriever"}
|
||||
|
||||
token, err := cache.GetAccessToken(ctx, credential, scopes1)
|
||||
token, err := cache.GetAccessToken(ctx, tokenRetriever, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token)
|
||||
assert.Equal(t, "retriever-token-1", token)
|
||||
|
||||
assert.Equal(t, 1, credential.calledTimes)
|
||||
assert.Equal(t, 1, tokenRetriever.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) {
|
||||
@ -68,7 +68,7 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
credential := &fakeRetriever{key: "credential-1"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
@ -94,8 +94,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential1 := &fakeCredential{key: "credential-1"}
|
||||
credential2 := &fakeCredential{key: "credential-2"}
|
||||
credential1 := &fakeRetriever{key: "credential-1"}
|
||||
credential2 := &fakeRetriever{key: "credential-2"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
@ -119,8 +119,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when credential init returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
t.Run("when retriever init returns error", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
@ -128,7 +128,7 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
@ -137,10 +137,10 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -153,13 +153,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) {
|
||||
t.Run("when retriever init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
@ -169,9 +169,9 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
|
||||
t.Run("should call retriever init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -184,52 +184,52 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
t.Run("when retriever init panics", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
t.Run("should call retriever init again each time", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) {
|
||||
t.Run("when retriever init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
@ -239,23 +239,23 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) {
|
||||
t.Run("should call retriever init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
@ -263,13 +263,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -279,8 +279,8 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
t.Run("when retriever getAccessToken returns error", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
@ -289,9 +289,9 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx)
|
||||
@ -300,13 +300,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
assert.Equal(t, "", accessToken)
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
t.Run("should call retriever again each time and return error", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -319,13 +319,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
assert.Equal(t, 3, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
|
||||
t.Run("when retriever getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
retriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
@ -337,11 +337,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it returns error", func(t *testing.T) {
|
||||
t.Run("should call retriever again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: retriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
@ -358,54 +358,54 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
assert.Equal(t, 2, retriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
t.Run("when retriever getAccessToken panics", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
t.Run("should call retriever again each time", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
assert.Equal(t, 3, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
|
||||
t.Run("when retriever getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
@ -416,11 +416,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it panics", func(t *testing.T) {
|
||||
t.Run("should call retriever again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
@ -428,14 +428,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
@ -444,14 +444,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
assert.Equal(t, 2, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -4,12 +4,11 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -17,72 +16,92 @@ var (
|
||||
)
|
||||
|
||||
type AzureTokenProvider interface {
|
||||
GetAccessToken(ctx context.Context) (string, error)
|
||||
GetAccessToken(ctx context.Context, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
type tokenProviderImpl struct {
|
||||
cfg *setting.Cfg
|
||||
authParams *plugins.JwtTokenAuth
|
||||
tokenRetriever TokenRetriever
|
||||
}
|
||||
|
||||
func NewAzureAccessTokenProvider(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *tokenProviderImpl {
|
||||
return &tokenProviderImpl{
|
||||
cfg: cfg,
|
||||
authParams: authParams,
|
||||
func NewAzureAccessTokenProvider(cfg *setting.Cfg, credentials azcredentials.AzureCredentials) (AzureTokenProvider, error) {
|
||||
if cfg == nil {
|
||||
err := fmt.Errorf("parameter 'cfg' cannot be nil")
|
||||
return nil, err
|
||||
}
|
||||
if credentials == nil {
|
||||
err := fmt.Errorf("parameter 'credentials' cannot be nil")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context) (string, error) {
|
||||
var credential TokenCredential
|
||||
var tokenRetriever TokenRetriever
|
||||
|
||||
if provider.isManagedIdentityCredential() {
|
||||
if !provider.cfg.Azure.ManagedIdentityEnabled {
|
||||
switch c := credentials.(type) {
|
||||
case *azcredentials.AzureManagedIdentityCredentials:
|
||||
if !cfg.Azure.ManagedIdentityEnabled {
|
||||
err := fmt.Errorf("managed identity authentication is not enabled in Grafana config")
|
||||
return "", err
|
||||
return nil, err
|
||||
} else {
|
||||
credential = provider.getManagedIdentityCredential()
|
||||
tokenRetriever = getManagedIdentityTokenRetriever(cfg, c)
|
||||
}
|
||||
} else {
|
||||
credential = provider.getClientSecretCredential()
|
||||
case *azcredentials.AzureClientSecretCredentials:
|
||||
tokenRetriever = getClientSecretTokenRetriever(c)
|
||||
default:
|
||||
err := fmt.Errorf("credentials of type '%s' not supported by authentication provider", c.AzureAuthType())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := azureTokenCache.GetAccessToken(ctx, credential, provider.authParams.Scopes)
|
||||
if err != nil {
|
||||
tokenProvider := &tokenProviderImpl{
|
||||
tokenRetriever: tokenRetriever,
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
if ctx == nil {
|
||||
err := fmt.Errorf("parameter 'ctx' cannot be nil")
|
||||
return "", err
|
||||
}
|
||||
if scopes == nil {
|
||||
err := fmt.Errorf("parameter 'scopes' cannot be nil")
|
||||
return "", err
|
||||
}
|
||||
|
||||
accessToken, err := azureTokenCache.GetAccessToken(ctx, provider.tokenRetriever, scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) isManagedIdentityCredential() bool {
|
||||
authType := strings.ToLower(provider.authParams.Params["azure_auth_type"])
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
|
||||
// Type of authentication being determined by the following logic:
|
||||
// * If authType is set to 'msi' then user explicitly selected the managed identity authentication
|
||||
// * If authType isn't set but other fields are configured then it's a datasource which was configured
|
||||
// before managed identities where introduced, therefore use client secret authentication
|
||||
// * If authType and other fields aren't set then it means the datasource never been configured
|
||||
// and managed identity is the default authentication choice as long as managed identities are enabled
|
||||
return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled)
|
||||
func getManagedIdentityTokenRetriever(cfg *setting.Cfg, credentials *azcredentials.AzureManagedIdentityCredentials) TokenRetriever {
|
||||
var clientId string
|
||||
if credentials.ClientId != "" {
|
||||
clientId = credentials.ClientId
|
||||
} else {
|
||||
clientId = cfg.Azure.ManagedIdentityClientId
|
||||
}
|
||||
return &managedIdentityTokenRetriever{
|
||||
clientId: clientId,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) getManagedIdentityCredential() TokenCredential {
|
||||
clientId := provider.cfg.Azure.ManagedIdentityClientId
|
||||
|
||||
return &managedIdentityCredential{clientId: clientId}
|
||||
func getClientSecretTokenRetriever(credentials *azcredentials.AzureClientSecretCredentials) TokenRetriever {
|
||||
var authority string
|
||||
if credentials.Authority != "" {
|
||||
authority = credentials.Authority
|
||||
} else {
|
||||
authority = resolveAuthorityForCloud(credentials.AzureCloud)
|
||||
}
|
||||
return &clientSecretTokenRetriever{
|
||||
authority: authority,
|
||||
tenantId: credentials.TenantId,
|
||||
clientId: credentials.ClientId,
|
||||
clientSecret: credentials.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) getClientSecretCredential() TokenCredential {
|
||||
authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"])
|
||||
tenantId := provider.authParams.Params["tenant_id"]
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
clientSecret := provider.authParams.Params["client_secret"]
|
||||
|
||||
return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret}
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string {
|
||||
func resolveAuthorityForCloud(cloudName string) string {
|
||||
// Known Azure clouds
|
||||
switch cloudName {
|
||||
case setting.AzurePublic:
|
||||
@ -93,17 +112,17 @@ func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string
|
||||
return azidentity.AzureGovernment
|
||||
case setting.AzureGermany:
|
||||
return azidentity.AzureGermany
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
// Fallback to direct URL
|
||||
return provider.authParams.Url
|
||||
}
|
||||
|
||||
type managedIdentityCredential struct {
|
||||
type managedIdentityTokenRetriever struct {
|
||||
clientId string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
func (c *managedIdentityTokenRetriever) GetCacheKey() string {
|
||||
clientId := c.clientId
|
||||
if clientId == "" {
|
||||
clientId = "system"
|
||||
@ -111,7 +130,7 @@ func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|msi|%s", clientId)
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) Init() error {
|
||||
func (c *managedIdentityTokenRetriever) Init() error {
|
||||
if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil {
|
||||
return err
|
||||
} else {
|
||||
@ -120,7 +139,7 @@ func (c *managedIdentityCredential) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
func (c *managedIdentityTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -129,7 +148,7 @@ func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes [
|
||||
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
type clientSecretCredential struct {
|
||||
type clientSecretTokenRetriever struct {
|
||||
authority string
|
||||
tenantId string
|
||||
clientId string
|
||||
@ -137,11 +156,11 @@ type clientSecretCredential struct {
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetCacheKey() string {
|
||||
func (c *clientSecretTokenRetriever) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) Init() error {
|
||||
func (c *clientSecretTokenRetriever) Init() error {
|
||||
options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority}
|
||||
if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil {
|
||||
return err
|
||||
@ -151,7 +170,7 @@ func (c *clientSecretCredential) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
func (c *clientSecretTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -4,125 +4,30 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var getAccessTokenFunc func(credential TokenCredential, scopes []string)
|
||||
var getAccessTokenFunc func(credential TokenRetriever, scopes []string)
|
||||
|
||||
type tokenCacheFake struct{}
|
||||
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenRetriever, scopes []string) (string, error) {
|
||||
getAccessTokenFunc(credential, scopes)
|
||||
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if auth type is client secret", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "clientsecret",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be managed identity if datasource not configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "06da9207-bdd9-4558-aee4-377450893cb4",
|
||||
"client_id": "b8c58fe8-1fca-4e30-a0a8-b44d0e5f70d6",
|
||||
"client_secret": "9bcd4434-824f-4887-a8a8-94c287bf0a7b",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when managed identities disabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
|
||||
assert.True(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
|
||||
t.Run("should be client secret if datasource not configured", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
}
|
||||
|
||||
assert.False(t, provider.isManagedIdentityCredential())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
func TestAzureTokenProvider_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "",
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
},
|
||||
scopes := []string{
|
||||
"https://management.azure.com/.default",
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
|
||||
original := azureTokenCache
|
||||
azureTokenCache = &tokenCacheFake{}
|
||||
t.Cleanup(func() { azureTokenCache = original })
|
||||
@ -130,29 +35,31 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should resolve managed identity credential if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
t.Run("should resolve managed identity retriever if auth type is managed identity", func(t *testing.T) {
|
||||
credentials := &azcredentials.AzureManagedIdentityCredentials{}
|
||||
|
||||
provider, err := NewAzureAccessTokenProvider(cfg, credentials)
|
||||
require.NoError(t, err)
|
||||
|
||||
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
|
||||
assert.IsType(t, &managedIdentityTokenRetriever{}, credential)
|
||||
}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.IsType(t, &managedIdentityCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
_, err = provider.GetAccessToken(ctx, scopes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("should resolve client secret credential if auth type is client secret", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "clientsecret",
|
||||
t.Run("should resolve client secret retriever if auth type is client secret", func(t *testing.T) {
|
||||
credentials := &azcredentials.AzureClientSecretCredentials{}
|
||||
|
||||
provider, err := NewAzureAccessTokenProvider(cfg, credentials)
|
||||
require.NoError(t, err)
|
||||
|
||||
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, credential)
|
||||
}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.IsType(t, &clientSecretCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
_, err = provider.GetAccessToken(ctx, scopes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
@ -161,47 +68,61 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should return error if auth type is managed identity", func(t *testing.T) {
|
||||
authParams.Params = map[string]string{
|
||||
"azure_auth_type": "msi",
|
||||
}
|
||||
credentials := &azcredentials.AzureManagedIdentityCredentials{}
|
||||
|
||||
getAccessTokenFunc = func(credential TokenCredential, scopes []string) {
|
||||
assert.Fail(t, "token cache not expected to be called")
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
require.Error(t, err)
|
||||
_, err := NewAzureAccessTokenProvider(cfg, credentials)
|
||||
assert.Error(t, err, "managed identity authentication is not enabled in Grafana config")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
Scopes: []string{
|
||||
"https://management.azure.com/.default",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": "",
|
||||
"azure_cloud": "AzureCloud",
|
||||
"tenant_id": "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
|
||||
"client_id": "1af7c188-e5b6-4f96-81b8-911761bdd459",
|
||||
"client_secret": "0416d95e-8af8-472c-aaa3-15c93c46080a",
|
||||
},
|
||||
credentials := &azcredentials.AzureClientSecretCredentials{
|
||||
AzureCloud: setting.AzurePublic,
|
||||
Authority: "",
|
||||
TenantId: "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
|
||||
ClientId: "1af7c188-e5b6-4f96-81b8-911761bdd459",
|
||||
ClientSecret: "0416d95e-8af8-472c-aaa3-15c93c46080a",
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
t.Run("should return clientSecretTokenRetriever with values", func(t *testing.T) {
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
|
||||
result := provider.getClientSecretCredential()
|
||||
assert.IsType(t, &clientSecretCredential{}, result)
|
||||
|
||||
credential := (result).(*clientSecretCredential)
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://login.microsoftonline.com/", credential.authority)
|
||||
assert.Equal(t, "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", credential.tenantId)
|
||||
assert.Equal(t, "1af7c188-e5b6-4f96-81b8-911761bdd459", credential.clientId)
|
||||
assert.Equal(t, "0416d95e-8af8-472c-aaa3-15c93c46080a", credential.clientSecret)
|
||||
})
|
||||
|
||||
t.Run("authority should selected based on cloud", func(t *testing.T) {
|
||||
originalCloud := credentials.AzureCloud
|
||||
defer func() { credentials.AzureCloud = originalCloud }()
|
||||
|
||||
credentials.AzureCloud = setting.AzureChina
|
||||
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://login.chinacloudapi.cn/", credential.authority)
|
||||
})
|
||||
|
||||
t.Run("explicitly set authority should have priority over cloud", func(t *testing.T) {
|
||||
originalCloud := credentials.AzureCloud
|
||||
defer func() { credentials.AzureCloud = originalCloud }()
|
||||
|
||||
credentials.AzureCloud = setting.AzureChina
|
||||
credentials.Authority = "https://another.com/"
|
||||
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://another.com/", credential.authority)
|
||||
})
|
||||
}
|
||||
|
@ -11,12 +11,14 @@ import (
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
|
||||
"github.com/grafana/grafana/pkg/registry"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -44,20 +46,15 @@ type Service struct {
|
||||
}
|
||||
|
||||
type azureMonitorSettings struct {
|
||||
SubscriptionId string `json:"subscriptionId"`
|
||||
LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"`
|
||||
AppInsightsAppId string `json:"appInsightsAppId"`
|
||||
AzureLogAnalyticsSameAs bool `json:"azureLogAnalyticsSameAs"`
|
||||
ClientId string `json:"clientId"`
|
||||
CloudName string `json:"cloudName"`
|
||||
LogAnalyticsClientId string `json:"logAnalyticsClientId"`
|
||||
LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"`
|
||||
LogAnalyticsSubscriptionId string `json:"logAnalyticsSubscriptionId"`
|
||||
LogAnalyticsTenantId string `json:"logAnalyticsTenantId"`
|
||||
SubscriptionId string `json:"subscriptionId"`
|
||||
TenantId string `json:"tenantId"`
|
||||
AzureAuthType string `json:"azureAuthType,omitempty"`
|
||||
}
|
||||
|
||||
type datasourceInfo struct {
|
||||
Cloud string
|
||||
Credentials azcredentials.AzureCredentials
|
||||
Settings azureMonitorSettings
|
||||
Services map[string]datasourceService
|
||||
Routes map[string]azRoute
|
||||
@ -74,10 +71,15 @@ type datasourceService struct {
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewInstanceSettings() datasource.InstanceFactoryFunc {
|
||||
func NewInstanceSettings(cfg *setting.Cfg) datasource.InstanceFactoryFunc {
|
||||
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
|
||||
jsonData := map[string]interface{}{}
|
||||
err := json.Unmarshal(settings.JSONData, &jsonData)
|
||||
jsonData, err := simplejson.NewJson(settings.JSONData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading settings: %w", err)
|
||||
}
|
||||
|
||||
jsonDataObj := map[string]interface{}{}
|
||||
err = json.Unmarshal(settings.JSONData, &jsonDataObj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading settings: %w", err)
|
||||
}
|
||||
@ -87,20 +89,34 @@ func NewInstanceSettings() datasource.InstanceFactoryFunc {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading settings: %w", err)
|
||||
}
|
||||
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting credentials: %w", err)
|
||||
}
|
||||
|
||||
credentials, err := getAzureCredentials(cfg, jsonData, settings.DecryptedSecureJSONData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting credentials: %w", err)
|
||||
}
|
||||
|
||||
httpCliOpts, err := settings.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting http options: %w", err)
|
||||
}
|
||||
|
||||
model := datasourceInfo{
|
||||
Cloud: cloud,
|
||||
Credentials: credentials,
|
||||
Settings: azMonitorSettings,
|
||||
JSONData: jsonData,
|
||||
JSONData: jsonDataObj,
|
||||
DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
|
||||
DatasourceID: settings.ID,
|
||||
Services: map[string]datasourceService{},
|
||||
Routes: routes[azMonitorSettings.CloudName],
|
||||
Routes: routes[cloud],
|
||||
HTTPCliOpts: httpCliOpts,
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
@ -141,7 +157,7 @@ func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors ma
|
||||
}
|
||||
|
||||
func (s *Service) Init() error {
|
||||
im := datasource.NewInstanceManager(NewInstanceSettings())
|
||||
im := datasource.NewInstanceManager(NewInstanceSettings(s.Cfg))
|
||||
executors := map[string]azDatasourceExecutor{
|
||||
azureMonitor: &AzureMonitorDatasource{},
|
||||
appInsights: &ApplicationInsightsDatasource{},
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -22,14 +23,16 @@ func TestNewInstanceSettings(t *testing.T) {
|
||||
{
|
||||
name: "creates an instance",
|
||||
settings: backend.DataSourceInstanceSettings{
|
||||
JSONData: []byte(`{"cloudName":"azuremonitor"}`),
|
||||
JSONData: []byte(`{"azureAuthType":"msi"}`),
|
||||
DecryptedSecureJSONData: map[string]string{"key": "value"},
|
||||
ID: 40,
|
||||
},
|
||||
expectedModel: datasourceInfo{
|
||||
Settings: azureMonitorSettings{CloudName: "azuremonitor"},
|
||||
Routes: routes["azuremonitor"],
|
||||
JSONData: map[string]interface{}{"cloudName": string("azuremonitor")},
|
||||
Cloud: setting.AzurePublic,
|
||||
Credentials: &azcredentials.AzureManagedIdentityCredentials{},
|
||||
Settings: azureMonitorSettings{},
|
||||
Routes: routes[setting.AzurePublic],
|
||||
JSONData: map[string]interface{}{"azureAuthType": "msi"},
|
||||
DatasourceID: 40,
|
||||
DecryptedSecureJSONData: map[string]string{"key": "value"},
|
||||
},
|
||||
@ -37,9 +40,15 @@ func TestNewInstanceSettings(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzurePublic,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
factory := NewInstanceSettings()
|
||||
factory := NewInstanceSettings(cfg)
|
||||
instance, err := factory(tt.settings)
|
||||
tt.Err(t, err)
|
||||
if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) {
|
||||
|
@ -1,12 +1,11 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"fmt"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
)
|
||||
|
||||
// Azure cloud names specific to Azure Monitor
|
||||
@ -17,40 +16,107 @@ const (
|
||||
azureMonitorGermany = "germanyazuremonitor"
|
||||
)
|
||||
|
||||
// Azure cloud query types
|
||||
const (
|
||||
azureMonitor = "Azure Monitor"
|
||||
appInsights = "Application Insights"
|
||||
azureLogAnalytics = "Azure Log Analytics"
|
||||
insightsAnalytics = "Insights Analytics"
|
||||
azureResourceGraph = "Azure Resource Graph"
|
||||
)
|
||||
|
||||
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
|
||||
if len(route.Scopes) > 0 {
|
||||
tokenAuth := &plugins.JwtTokenAuth{
|
||||
Url: route.URL,
|
||||
Scopes: route.Scopes,
|
||||
Params: map[string]string{
|
||||
"azure_auth_type": model.Settings.AzureAuthType,
|
||||
"azure_cloud": cfg.Azure.Cloud,
|
||||
"tenant_id": model.Settings.TenantId,
|
||||
"client_id": model.Settings.ClientId,
|
||||
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
|
||||
},
|
||||
}
|
||||
tokenProvider := aztokenprovider.NewAzureAccessTokenProvider(cfg, tokenAuth)
|
||||
return httpclient.NewProvider(httpclient.ProviderOptions{
|
||||
Middlewares: []httpclient.Middleware{
|
||||
aztokenprovider.AuthMiddleware(tokenProvider),
|
||||
},
|
||||
})
|
||||
func getAuthType(cfg *setting.Cfg, jsonData *simplejson.Json) string {
|
||||
if azureAuthType := jsonData.Get("azureAuthType").MustString(); azureAuthType != "" {
|
||||
return azureAuthType
|
||||
} else {
|
||||
return httpclient.NewProvider()
|
||||
tenantId := jsonData.Get("tenantId").MustString()
|
||||
clientId := jsonData.Get("clientId").MustString()
|
||||
|
||||
// If authentication type isn't explicitly specified and datasource has client credentials,
|
||||
// then this is existing datasource which is configured for app registration (client secret)
|
||||
if tenantId != "" && clientId != "" {
|
||||
return azcredentials.AzureAuthClientSecret
|
||||
}
|
||||
|
||||
// For newly created datasource with no configuration, managed identity is the default authentication type
|
||||
// if they are enabled in Grafana config
|
||||
if cfg.Azure.ManagedIdentityEnabled {
|
||||
return azcredentials.AzureAuthManagedIdentity
|
||||
} else {
|
||||
return azcredentials.AzureAuthClientSecret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
|
||||
model.HTTPCliOpts.Headers = route.Headers
|
||||
return httpClientProvider(route, model, cfg).New(model.HTTPCliOpts)
|
||||
func getDefaultAzureCloud(cfg *setting.Cfg) (string, error) {
|
||||
// Allow only known cloud names
|
||||
cloudName := cfg.Azure.Cloud
|
||||
switch cloudName {
|
||||
case setting.AzurePublic:
|
||||
return setting.AzurePublic, nil
|
||||
case setting.AzureChina:
|
||||
return setting.AzureChina, nil
|
||||
case setting.AzureUSGovernment:
|
||||
return setting.AzureUSGovernment, nil
|
||||
case setting.AzureGermany:
|
||||
return setting.AzureGermany, nil
|
||||
case "":
|
||||
// Not set cloud defaults to public
|
||||
return setting.AzurePublic, nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", cloudName)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAzureCloud(cloudName string) (string, error) {
|
||||
switch cloudName {
|
||||
case azureMonitorPublic:
|
||||
return setting.AzurePublic, nil
|
||||
case azureMonitorChina:
|
||||
return setting.AzureChina, nil
|
||||
case azureMonitorUSGovernment:
|
||||
return setting.AzureUSGovernment, nil
|
||||
case azureMonitorGermany:
|
||||
return setting.AzureGermany, nil
|
||||
default:
|
||||
err := fmt.Errorf("the cloud '%s' not supported", cloudName)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureCloud(cfg *setting.Cfg, jsonData *simplejson.Json) (string, error) {
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
switch authType {
|
||||
case azcredentials.AzureAuthManagedIdentity:
|
||||
// In case of managed identity, the cloud is always same as where Grafana is hosted
|
||||
return getDefaultAzureCloud(cfg)
|
||||
case azcredentials.AzureAuthClientSecret:
|
||||
if cloud := jsonData.Get("cloudName").MustString(); cloud != "" {
|
||||
return normalizeAzureCloud(cloud)
|
||||
} else {
|
||||
return getDefaultAzureCloud(cfg)
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("the authentication type '%s' not supported", authType)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureCredentials(cfg *setting.Cfg, jsonData *simplejson.Json, secureJsonData map[string]string) (azcredentials.AzureCredentials, error) {
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
switch authType {
|
||||
case azcredentials.AzureAuthManagedIdentity:
|
||||
credentials := &azcredentials.AzureManagedIdentityCredentials{}
|
||||
return credentials, nil
|
||||
|
||||
case azcredentials.AzureAuthClientSecret:
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
credentials := &azcredentials.AzureClientSecretCredentials{
|
||||
AzureCloud: cloud,
|
||||
TenantId: jsonData.Get("tenantId").MustString(),
|
||||
ClientId: jsonData.Get("clientId").MustString(),
|
||||
ClientSecret: secureJsonData["clientSecret"],
|
||||
}
|
||||
return credentials, nil
|
||||
|
||||
default:
|
||||
err := fmt.Errorf("the authentication type '%s' not supported", authType)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -3,50 +3,195 @@ package azuremonitor
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_httpCliProvider(t *testing.T) {
|
||||
func TestCredentials_getAuthType(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
model := datasourceInfo{
|
||||
Settings: azureMonitorSettings{},
|
||||
DecryptedSecureJSONData: map[string]string{"clientSecret": "content"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
route azRoute
|
||||
expectedMiddlewares int
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates an HTTP client with a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{"http://route/.default"},
|
||||
},
|
||||
expectedMiddlewares: 1,
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "creates an HTTP client without a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{},
|
||||
},
|
||||
// httpclient.NewProvider returns a client with 2 middlewares by default
|
||||
expectedMiddlewares: 2,
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := httpClientProvider(tt.route, model, cfg)
|
||||
// Cannot test that the cli middleware works properly since the azcore sdk
|
||||
// rejects the TLS certs (if provided)
|
||||
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
|
||||
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
|
||||
}
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should be client secret if auth type is set to client secret", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthClientSecret,
|
||||
})
|
||||
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("should be managed identity if datasource not configured", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": "",
|
||||
})
|
||||
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType)
|
||||
})
|
||||
|
||||
t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": "",
|
||||
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
|
||||
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
|
||||
})
|
||||
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when managed identities disabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should be managed identity if auth type is set to managed identity", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
|
||||
})
|
||||
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType)
|
||||
})
|
||||
|
||||
t.Run("should be client secret if datasource not configured", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": "",
|
||||
})
|
||||
|
||||
authType := getAuthType(cfg, jsonData)
|
||||
|
||||
assert.Equal(t, azcredentials.AzureAuthClientSecret, authType)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentials_getAzureCloud(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzureChina,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("when auth type is managed identity", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
|
||||
"cloudName": azureMonitorGermany,
|
||||
})
|
||||
|
||||
t.Run("should be from server configuration regardless of datasource value", func(t *testing.T) {
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, setting.AzureChina, cloud)
|
||||
})
|
||||
|
||||
t.Run("should be public if not set in server configuration", func(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: "",
|
||||
},
|
||||
}
|
||||
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, setting.AzurePublic, cloud)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when auth type is client secret", func(t *testing.T) {
|
||||
t.Run("should be from datasource value normalized to known cloud name", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthClientSecret,
|
||||
"cloudName": azureMonitorGermany,
|
||||
})
|
||||
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, setting.AzureGermany, cloud)
|
||||
})
|
||||
|
||||
t.Run("should be from server configuration if not set in datasource", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthClientSecret,
|
||||
"cloudName": "",
|
||||
})
|
||||
|
||||
cloud, err := getAzureCloud(cfg, jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, setting.AzureChina, cloud)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentials_getAzureCredentials(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzureChina,
|
||||
},
|
||||
}
|
||||
|
||||
secureJsonData := map[string]string{
|
||||
"clientSecret": "59e3498f-eb12-4943-b8f0-a5aa42640058",
|
||||
}
|
||||
|
||||
t.Run("when auth type is managed identity", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthManagedIdentity,
|
||||
"cloudName": azureMonitorGermany,
|
||||
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
|
||||
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
|
||||
})
|
||||
|
||||
t.Run("should return managed identity credentials", func(t *testing.T) {
|
||||
credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &azcredentials.AzureManagedIdentityCredentials{}, credentials)
|
||||
msiCredentials := credentials.(*azcredentials.AzureManagedIdentityCredentials)
|
||||
|
||||
// Azure Monitor datasource doesn't support user-assigned managed identities (ClientId is always empty)
|
||||
assert.Equal(t, "", msiCredentials.ClientId)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when auth type is client secret", func(t *testing.T) {
|
||||
jsonData := simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuthType": azcredentials.AzureAuthClientSecret,
|
||||
"cloudName": azureMonitorGermany,
|
||||
"tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c",
|
||||
"clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6",
|
||||
})
|
||||
|
||||
t.Run("should return client secret credentials", func(t *testing.T) {
|
||||
cfg := &setting.Cfg{
|
||||
Azure: setting.AzureSettings{
|
||||
Cloud: setting.AzureChina,
|
||||
},
|
||||
}
|
||||
|
||||
credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &azcredentials.AzureClientSecretCredentials{}, credentials)
|
||||
clientSecretCredentials := credentials.(*azcredentials.AzureClientSecretCredentials)
|
||||
|
||||
assert.Equal(t, setting.AzureGermany, clientSecretCredentials.AzureCloud)
|
||||
assert.Equal(t, "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", clientSecretCredentials.TenantId)
|
||||
assert.Equal(t, "849ccbb0-92eb-4226-b228-ef391abd8fe6", clientSecretCredentials.ClientId)
|
||||
assert.Equal(t, "59e3498f-eb12-4943-b8f0-a5aa42640058", clientSecretCredentials.ClientSecret)
|
||||
|
||||
// Azure Monitor datasource doesn't support custom IdP authorities (Authority is always empty)
|
||||
assert.Equal(t, "", clientSecretCredentials.Authority)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
41
pkg/tsdb/azuremonitor/httpclient.go
Normal file
41
pkg/tsdb/azuremonitor/httpclient.go
Normal file
@ -0,0 +1,41 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*httpclient.Provider, error) {
|
||||
var clientProvider *httpclient.Provider
|
||||
|
||||
if len(route.Scopes) > 0 {
|
||||
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, model.Credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientProvider = httpclient.NewProvider(httpclient.ProviderOptions{
|
||||
Middlewares: []httpclient.Middleware{
|
||||
aztokenprovider.AuthMiddleware(tokenProvider, route.Scopes),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
clientProvider = httpclient.NewProvider()
|
||||
}
|
||||
|
||||
return clientProvider, nil
|
||||
}
|
||||
|
||||
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
|
||||
model.HTTPCliOpts.Headers = route.Headers
|
||||
|
||||
clientProvider, err := httpClientProvider(route, model, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientProvider.New(model.HTTPCliOpts)
|
||||
}
|
54
pkg/tsdb/azuremonitor/httpclient_test.go
Normal file
54
pkg/tsdb/azuremonitor/httpclient_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_httpCliProvider(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
model := datasourceInfo{
|
||||
Credentials: &azcredentials.AzureClientSecretCredentials{},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
route azRoute
|
||||
expectedMiddlewares int
|
||||
Err require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "creates an HTTP client with a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{"http://route/.default"},
|
||||
},
|
||||
expectedMiddlewares: 1,
|
||||
Err: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "creates an HTTP client without a middleware",
|
||||
route: azRoute{
|
||||
URL: "http://route",
|
||||
Scopes: []string{},
|
||||
},
|
||||
// httpclient.NewProvider returns a client with 2 middlewares by default
|
||||
expectedMiddlewares: 2,
|
||||
Err: require.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli, err := httpClientProvider(tt.route, model, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cannot test that the cli middleware works properly since the azcore sdk
|
||||
// rejects the TLS certs (if provided)
|
||||
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
|
||||
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,5 +1,16 @@
|
||||
package azuremonitor
|
||||
|
||||
import "github.com/grafana/grafana/pkg/setting"
|
||||
|
||||
// Azure cloud query types
|
||||
const (
|
||||
azureMonitor = "Azure Monitor"
|
||||
appInsights = "Application Insights"
|
||||
azureLogAnalytics = "Azure Log Analytics"
|
||||
insightsAnalytics = "Insights Analytics"
|
||||
azureResourceGraph = "Azure Resource Graph"
|
||||
)
|
||||
|
||||
type azRoute struct {
|
||||
URL string
|
||||
Scopes []string
|
||||
@ -64,22 +75,22 @@ var (
|
||||
// The different Azure routes are identified by its cloud (e.g. public or gov)
|
||||
// and the service to query (e.g. Azure Monitor or Azure Log Analytics)
|
||||
routes = map[string]map[string]azRoute{
|
||||
azureMonitorPublic: {
|
||||
setting.AzurePublic: {
|
||||
azureMonitor: azManagement,
|
||||
azureLogAnalytics: azLogAnalytics,
|
||||
azureResourceGraph: azManagement,
|
||||
appInsights: azAppInsights,
|
||||
insightsAnalytics: azAppInsights,
|
||||
},
|
||||
azureMonitorUSGovernment: {
|
||||
setting.AzureUSGovernment: {
|
||||
azureMonitor: azUSGovManagement,
|
||||
azureLogAnalytics: azUSGovLogAnalytics,
|
||||
azureResourceGraph: azUSGovManagement,
|
||||
},
|
||||
azureMonitorGermany: {
|
||||
setting.AzureGermany: {
|
||||
azureMonitor: azGermanyManagement,
|
||||
},
|
||||
azureMonitorChina: {
|
||||
setting.AzureChina: {
|
||||
azureMonitor: azChinaManagement,
|
||||
azureLogAnalytics: azChinaLogAnalytics,
|
||||
azureResourceGraph: azChinaManagement,
|
||||
|
Loading…
Reference in New Issue
Block a user