From 89ba6073824a2df27dcaffedc2dcbe85c99f953a Mon Sep 17 00:00:00 2001 From: Sergey Kostrukov Date: Mon, 5 Jul 2021 03:20:12 -0700 Subject: [PATCH] AzureMonitor: strongly-typed AzureCredentials and correct resolution of auth type and cloud (#36284) --- pkg/api/pluginproxy/ds_auth_provider.go | 3 +- pkg/api/pluginproxy/token_provider_azure.go | 42 +++- .../azuremonitor/azcredentials/credentials.go | 30 +++ .../aztokenprovider/middleware.go | 4 +- .../aztokenprovider/token_cache.go | 32 +-- .../aztokenprovider/token_cache_test.go | 160 ++++++------- .../aztokenprovider/token_provider.go | 129 +++++----- .../aztokenprovider/token_provider_test.go | 211 +++++----------- pkg/tsdb/azuremonitor/azuremonitor.go | 46 ++-- pkg/tsdb/azuremonitor/azuremonitor_test.go | 19 +- pkg/tsdb/azuremonitor/credentials.go | 138 ++++++++--- pkg/tsdb/azuremonitor/credentials_test.go | 225 ++++++++++++++---- pkg/tsdb/azuremonitor/httpclient.go | 41 ++++ pkg/tsdb/azuremonitor/httpclient_test.go | 54 +++++ pkg/tsdb/azuremonitor/routes.go | 19 +- 15 files changed, 749 insertions(+), 404 deletions(-) create mode 100644 pkg/tsdb/azuremonitor/azcredentials/credentials.go create mode 100644 pkg/tsdb/azuremonitor/httpclient.go create mode 100644 pkg/tsdb/azuremonitor/httpclient_test.go diff --git a/pkg/api/pluginproxy/ds_auth_provider.go b/pkg/api/pluginproxy/ds_auth_provider.go index 6a63c2e44eb..e2a11ac4475 100644 --- a/pkg/api/pluginproxy/ds_auth_provider.go +++ b/pkg/api/pluginproxy/ds_auth_provider.go @@ -92,8 +92,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour if tokenAuth == nil { return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType) } - provider := newAzureAccessTokenProvider(ctx, cfg, tokenAuth) - return provider, nil + return newAzureAccessTokenProvider(ctx, cfg, tokenAuth) case "gce": if jwtTokenAuth == nil { diff --git a/pkg/api/pluginproxy/token_provider_azure.go b/pkg/api/pluginproxy/token_provider_azure.go index 829e3875a6c..ff32df8f6a7 100644 --- a/pkg/api/pluginproxy/token_provider_azure.go +++ b/pkg/api/pluginproxy/token_provider_azure.go @@ -2,24 +2,58 @@ package pluginproxy import ( "context" + "strings" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" "github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider" ) type azureAccessTokenProvider struct { ctx context.Context tokenProvider aztokenprovider.AzureTokenProvider + scopes []string } -func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider { +func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) (*azureAccessTokenProvider, error) { + credentials := getAzureCredentials(cfg, authParams) + tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, credentials) + if err != nil { + return nil, err + } return &azureAccessTokenProvider{ ctx: ctx, - tokenProvider: aztokenprovider.NewAzureAccessTokenProvider(cfg, authParams), - } + tokenProvider: tokenProvider, + scopes: authParams.Scopes, + }, nil } func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) { - return provider.tokenProvider.GetAccessToken(provider.ctx) + return provider.tokenProvider.GetAccessToken(provider.ctx, provider.scopes) +} + +func getAzureCredentials(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) azcredentials.AzureCredentials { + authType := strings.ToLower(authParams.Params["azure_auth_type"]) + clientId := authParams.Params["client_id"] + + // Type of authentication being determined by the following logic: + // * If authType is set to 'msi' then user explicitly selected the managed identity authentication + // * If authType isn't set but other fields are configured then it's a datasource which was configured + // before managed identities where introduced, therefore use client secret authentication + // * If authType and other fields aren't set then it means the datasource never been configured + // and managed identity is the default authentication choice as long as managed identities are enabled + isManagedIdentity := authType == "msi" || (authType == "" && clientId == "" && cfg.Azure.ManagedIdentityEnabled) + + if isManagedIdentity { + return &azcredentials.AzureManagedIdentityCredentials{} + } else { + return &azcredentials.AzureClientSecretCredentials{ + AzureCloud: authParams.Params["azure_cloud"], + Authority: authParams.Url, + TenantId: authParams.Params["tenant_id"], + ClientId: authParams.Params["client_id"], + ClientSecret: authParams.Params["client_secret"], + } + } } diff --git a/pkg/tsdb/azuremonitor/azcredentials/credentials.go b/pkg/tsdb/azuremonitor/azcredentials/credentials.go new file mode 100644 index 00000000000..49311e5c11a --- /dev/null +++ b/pkg/tsdb/azuremonitor/azcredentials/credentials.go @@ -0,0 +1,30 @@ +package azcredentials + +const ( + AzureAuthManagedIdentity = "msi" + AzureAuthClientSecret = "clientsecret" +) + +type AzureCredentials interface { + AzureAuthType() string +} + +type AzureManagedIdentityCredentials struct { + ClientId string +} + +type AzureClientSecretCredentials struct { + AzureCloud string + Authority string + TenantId string + ClientId string + ClientSecret string +} + +func (credentials *AzureManagedIdentityCredentials) AzureAuthType() string { + return AzureAuthManagedIdentity +} + +func (credentials *AzureClientSecretCredentials) AzureAuthType() string { + return AzureAuthClientSecret +} diff --git a/pkg/tsdb/azuremonitor/aztokenprovider/middleware.go b/pkg/tsdb/azuremonitor/aztokenprovider/middleware.go index c6f06be22fa..d80f75002f7 100644 --- a/pkg/tsdb/azuremonitor/aztokenprovider/middleware.go +++ b/pkg/tsdb/azuremonitor/aztokenprovider/middleware.go @@ -9,10 +9,10 @@ import ( const authenticationMiddlewareName = "AzureAuthentication" -func AuthMiddleware(tokenProvider AzureTokenProvider) httpclient.Middleware { +func AuthMiddleware(tokenProvider AzureTokenProvider, scopes []string) httpclient.Middleware { return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - token, err := tokenProvider.GetAccessToken(req.Context()) + token, err := tokenProvider.GetAccessToken(req.Context(), scopes) if err != nil { return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err) } diff --git a/pkg/tsdb/azuremonitor/aztokenprovider/token_cache.go b/pkg/tsdb/azuremonitor/aztokenprovider/token_cache.go index 9de9cb3ced4..3c375979e4d 100644 --- a/pkg/tsdb/azuremonitor/aztokenprovider/token_cache.go +++ b/pkg/tsdb/azuremonitor/aztokenprovider/token_cache.go @@ -19,14 +19,14 @@ type AccessToken struct { ExpiresOn time.Time } -type TokenCredential interface { +type TokenRetriever interface { GetCacheKey() string Init() error GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) } type ConcurrentTokenCache interface { - GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) + GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error) } func NewConcurrentTokenCache() ConcurrentTokenCache { @@ -37,7 +37,7 @@ type tokenCacheImpl struct { cache sync.Map // of *credentialCacheEntry } type credentialCacheEntry struct { - credential TokenCredential + retriever TokenRetriever credInit uint32 credMutex sync.Mutex @@ -45,19 +45,19 @@ type credentialCacheEntry struct { } type scopesCacheEntry struct { - credential TokenCredential - scopes []string + retriever TokenRetriever + scopes []string cond *sync.Cond refreshing bool accessToken *AccessToken } -func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) { - return c.getEntryFor(credential).getAccessToken(ctx, scopes) +func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error) { + return c.getEntryFor(tokenRetriever).getAccessToken(ctx, scopes) } -func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry { +func (c *tokenCacheImpl) getEntryFor(credential TokenRetriever) *credentialCacheEntry { var entry interface{} var ok bool @@ -65,7 +65,7 @@ func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCach if entry, ok = c.cache.Load(key); !ok { entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{ - credential: credential, + retriever: credential, }) } @@ -87,8 +87,8 @@ func (c *credentialCacheEntry) ensureInitialized() error { defer c.credMutex.Unlock() if c.credInit == 0 { - // Initialize credential - err := c.credential.Init() + // Initialize retriever + err := c.retriever.Init() if err != nil { return err } @@ -108,9 +108,9 @@ func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry { if entry, ok = c.cache.Load(key); !ok { entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{ - credential: c.credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: c.retriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), }) } @@ -155,7 +155,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) { func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) { var accessToken *AccessToken - // Safeguarding from panic caused by credential implementation + // Safeguarding from panic caused by retriever implementation defer func() { c.cond.L.Lock() @@ -169,7 +169,7 @@ func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken c.cond.L.Unlock() }() - token, err := c.credential.GetAccessToken(ctx, c.scopes) + token, err := c.retriever.GetAccessToken(ctx, c.scopes) if err != nil { return nil, err } diff --git a/pkg/tsdb/azuremonitor/aztokenprovider/token_cache_test.go b/pkg/tsdb/azuremonitor/aztokenprovider/token_cache_test.go index ee812d88864..1d26d004c3f 100644 --- a/pkg/tsdb/azuremonitor/aztokenprovider/token_cache_test.go +++ b/pkg/tsdb/azuremonitor/aztokenprovider/token_cache_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -type fakeCredential struct { +type fakeRetriever struct { key string initCalledTimes int calledTimes int @@ -20,16 +20,16 @@ type fakeCredential struct { getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error) } -func (c *fakeCredential) GetCacheKey() string { +func (c *fakeRetriever) GetCacheKey() string { return c.key } -func (c *fakeCredential) Reset() { +func (c *fakeRetriever) Reset() { c.initCalledTimes = 0 c.calledTimes = 0 } -func (c *fakeCredential) Init() error { +func (c *fakeRetriever) Init() error { c.initCalledTimes = c.initCalledTimes + 1 if c.initFunc != nil { return c.initFunc() @@ -37,7 +37,7 @@ func (c *fakeCredential) Init() error { return nil } -func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { +func (c *fakeRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { c.calledTimes = c.calledTimes + 1 if c.getAccessTokenFunc != nil { return c.getAccessTokenFunc(ctx, scopes) @@ -52,15 +52,15 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { scopes1 := []string{"Scope1"} scopes2 := []string{"Scope2"} - t.Run("should request access token from credential", func(t *testing.T) { + t.Run("should request access token from retriever", func(t *testing.T) { cache := NewConcurrentTokenCache() - credential := &fakeCredential{key: "credential-1"} + tokenRetriever := &fakeRetriever{key: "retriever"} - token, err := cache.GetAccessToken(ctx, credential, scopes1) + token, err := cache.GetAccessToken(ctx, tokenRetriever, scopes1) require.NoError(t, err) - assert.Equal(t, "credential-1-token-1", token) + assert.Equal(t, "retriever-token-1", token) - assert.Equal(t, 1, credential.calledTimes) + assert.Equal(t, 1, tokenRetriever.calledTimes) }) t.Run("should return cached token for same scopes", func(t *testing.T) { @@ -68,7 +68,7 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { var err error cache := NewConcurrentTokenCache() - credential := &fakeCredential{key: "credential-1"} + credential := &fakeRetriever{key: "credential-1"} token1, err = cache.GetAccessToken(ctx, credential, scopes1) require.NoError(t, err) @@ -94,8 +94,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { var err error cache := NewConcurrentTokenCache() - credential1 := &fakeCredential{key: "credential-1"} - credential2 := &fakeCredential{key: "credential-2"} + credential1 := &fakeRetriever{key: "credential-1"} + credential2 := &fakeRetriever{key: "credential-2"} token1, err = cache.GetAccessToken(ctx, credential1, scopes1) require.NoError(t, err) @@ -119,8 +119,8 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { } func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { - t.Run("when credential init returns error", func(t *testing.T) { - credential := &fakeCredential{ + t.Run("when retriever init returns error", func(t *testing.T) { + tokenRetriever := &fakeRetriever{ initFunc: func() error { return errors.New("unable to initialize") }, @@ -128,7 +128,7 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { t.Run("should return error", func(t *testing.T) { cacheEntry := &credentialCacheEntry{ - credential: credential, + retriever: tokenRetriever, } err := cacheEntry.ensureInitialized() @@ -137,10 +137,10 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { }) t.Run("should call init again each time and return error", func(t *testing.T) { - credential.Reset() + tokenRetriever.Reset() cacheEntry := &credentialCacheEntry{ - credential: credential, + retriever: tokenRetriever, } var err error @@ -153,13 +153,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { err = cacheEntry.ensureInitialized() assert.Error(t, err) - assert.Equal(t, 3, credential.initCalledTimes) + assert.Equal(t, 3, tokenRetriever.initCalledTimes) }) }) - t.Run("when credential init returns error only once", func(t *testing.T) { + t.Run("when retriever init returns error only once", func(t *testing.T) { var times = 0 - credential := &fakeCredential{ + tokenRetriever := &fakeRetriever{ initFunc: func() error { times = times + 1 if times == 1 { @@ -169,9 +169,9 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { }, } - t.Run("should call credential init again only while it returns error", func(t *testing.T) { + t.Run("should call retriever init again only while it returns error", func(t *testing.T) { cacheEntry := &credentialCacheEntry{ - credential: credential, + retriever: tokenRetriever, } var err error @@ -184,52 +184,52 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { err = cacheEntry.ensureInitialized() assert.NoError(t, err) - assert.Equal(t, 2, credential.initCalledTimes) + assert.Equal(t, 2, tokenRetriever.initCalledTimes) }) }) - t.Run("when credential init panics", func(t *testing.T) { - credential := &fakeCredential{ + t.Run("when retriever init panics", func(t *testing.T) { + tokenRetriever := &fakeRetriever{ initFunc: func() error { panic(errors.New("unable to initialize")) }, } - t.Run("should call credential init again each time", func(t *testing.T) { - credential.Reset() + t.Run("should call retriever init again each time", func(t *testing.T) { + tokenRetriever.Reset() cacheEntry := &credentialCacheEntry{ - credential: credential, + retriever: tokenRetriever, } func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _ = cacheEntry.ensureInitialized() }() func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _ = cacheEntry.ensureInitialized() }() func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _ = cacheEntry.ensureInitialized() }() - assert.Equal(t, 3, credential.initCalledTimes) + assert.Equal(t, 3, tokenRetriever.initCalledTimes) }) }) - t.Run("when credential init panics only once", func(t *testing.T) { + t.Run("when retriever init panics only once", func(t *testing.T) { var times = 0 - credential := &fakeCredential{ + tokenRetriever := &fakeRetriever{ initFunc: func() error { times = times + 1 if times == 1 { @@ -239,23 +239,23 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { }, } - t.Run("should call credential init again only while it panics", func(t *testing.T) { + t.Run("should call retriever init again only while it panics", func(t *testing.T) { cacheEntry := &credentialCacheEntry{ - credential: credential, + retriever: tokenRetriever, } var err error func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _ = cacheEntry.ensureInitialized() }() func() { defer func() { - assert.Nil(t, recover(), "credential not expected to panic") + assert.Nil(t, recover(), "retriever not expected to panic") }() err = cacheEntry.ensureInitialized() assert.NoError(t, err) @@ -263,13 +263,13 @@ func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { func() { defer func() { - assert.Nil(t, recover(), "credential not expected to panic") + assert.Nil(t, recover(), "retriever not expected to panic") }() err = cacheEntry.ensureInitialized() assert.NoError(t, err) }() - assert.Equal(t, 2, credential.initCalledTimes) + assert.Equal(t, 2, tokenRetriever.initCalledTimes) }) }) } @@ -279,8 +279,8 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { scopes := []string{"Scope1"} - t.Run("when credential getAccessToken returns error", func(t *testing.T) { - credential := &fakeCredential{ + t.Run("when retriever getAccessToken returns error", func(t *testing.T) { + tokenRetriever := &fakeRetriever{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)} return invalidToken, errors.New("unable to get access token") @@ -289,9 +289,9 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { t.Run("should return error", func(t *testing.T) { cacheEntry := &scopesCacheEntry{ - credential: credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: tokenRetriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), } accessToken, err := cacheEntry.getAccessToken(ctx) @@ -300,13 +300,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { assert.Equal(t, "", accessToken) }) - t.Run("should call credential again each time and return error", func(t *testing.T) { - credential.Reset() + t.Run("should call retriever again each time and return error", func(t *testing.T) { + tokenRetriever.Reset() cacheEntry := &scopesCacheEntry{ - credential: credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: tokenRetriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), } var err error @@ -319,13 +319,13 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { _, err = cacheEntry.getAccessToken(ctx) assert.Error(t, err) - assert.Equal(t, 3, credential.calledTimes) + assert.Equal(t, 3, tokenRetriever.calledTimes) }) }) - t.Run("when credential getAccessToken returns error only once", func(t *testing.T) { + t.Run("when retriever getAccessToken returns error only once", func(t *testing.T) { var times = 0 - credential := &fakeCredential{ + retriever := &fakeRetriever{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { times = times + 1 if times == 1 { @@ -337,11 +337,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }, } - t.Run("should call credential again only while it returns error", func(t *testing.T) { + t.Run("should call retriever again only while it returns error", func(t *testing.T) { cacheEntry := &scopesCacheEntry{ - credential: credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: retriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), } var accessToken string @@ -358,54 +358,54 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "token-2", accessToken) - assert.Equal(t, 2, credential.calledTimes) + assert.Equal(t, 2, retriever.calledTimes) }) }) - t.Run("when credential getAccessToken panics", func(t *testing.T) { - credential := &fakeCredential{ + t.Run("when retriever getAccessToken panics", func(t *testing.T) { + tokenRetriever := &fakeRetriever{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { panic(errors.New("unable to get access token")) }, } - t.Run("should call credential again each time", func(t *testing.T) { - credential.Reset() + t.Run("should call retriever again each time", func(t *testing.T) { + tokenRetriever.Reset() cacheEntry := &scopesCacheEntry{ - credential: credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: tokenRetriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), } func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _, _ = cacheEntry.getAccessToken(ctx) }() func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _, _ = cacheEntry.getAccessToken(ctx) }() func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _, _ = cacheEntry.getAccessToken(ctx) }() - assert.Equal(t, 3, credential.calledTimes) + assert.Equal(t, 3, tokenRetriever.calledTimes) }) }) - t.Run("when credential getAccessToken panics only once", func(t *testing.T) { + t.Run("when retriever getAccessToken panics only once", func(t *testing.T) { var times = 0 - credential := &fakeCredential{ + tokenRetriever := &fakeRetriever{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { times = times + 1 if times == 1 { @@ -416,11 +416,11 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }, } - t.Run("should call credential again only while it panics", func(t *testing.T) { + t.Run("should call retriever again only while it panics", func(t *testing.T) { cacheEntry := &scopesCacheEntry{ - credential: credential, - scopes: scopes, - cond: sync.NewCond(&sync.Mutex{}), + retriever: tokenRetriever, + scopes: scopes, + cond: sync.NewCond(&sync.Mutex{}), } var accessToken string @@ -428,14 +428,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { func() { defer func() { - assert.NotNil(t, recover(), "credential expected to panic") + assert.NotNil(t, recover(), "retriever expected to panic") }() _, _ = cacheEntry.getAccessToken(ctx) }() func() { defer func() { - assert.Nil(t, recover(), "credential not expected to panic") + assert.Nil(t, recover(), "retriever not expected to panic") }() accessToken, err = cacheEntry.getAccessToken(ctx) assert.NoError(t, err) @@ -444,14 +444,14 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { func() { defer func() { - assert.Nil(t, recover(), "credential not expected to panic") + assert.Nil(t, recover(), "retriever not expected to panic") }() accessToken, err = cacheEntry.getAccessToken(ctx) assert.NoError(t, err) assert.Equal(t, "token-2", accessToken) }() - assert.Equal(t, 2, credential.calledTimes) + assert.Equal(t, 2, tokenRetriever.calledTimes) }) }) } diff --git a/pkg/tsdb/azuremonitor/aztokenprovider/token_provider.go b/pkg/tsdb/azuremonitor/aztokenprovider/token_provider.go index f1725ba2cf6..c5d9a289aec 100644 --- a/pkg/tsdb/azuremonitor/aztokenprovider/token_provider.go +++ b/pkg/tsdb/azuremonitor/aztokenprovider/token_provider.go @@ -4,12 +4,11 @@ import ( "context" "crypto/sha256" "fmt" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" ) var ( @@ -17,72 +16,92 @@ var ( ) type AzureTokenProvider interface { - GetAccessToken(ctx context.Context) (string, error) + GetAccessToken(ctx context.Context, scopes []string) (string, error) } type tokenProviderImpl struct { - cfg *setting.Cfg - authParams *plugins.JwtTokenAuth + tokenRetriever TokenRetriever } -func NewAzureAccessTokenProvider(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *tokenProviderImpl { - return &tokenProviderImpl{ - cfg: cfg, - authParams: authParams, +func NewAzureAccessTokenProvider(cfg *setting.Cfg, credentials azcredentials.AzureCredentials) (AzureTokenProvider, error) { + if cfg == nil { + err := fmt.Errorf("parameter 'cfg' cannot be nil") + return nil, err + } + if credentials == nil { + err := fmt.Errorf("parameter 'credentials' cannot be nil") + return nil, err } -} -func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context) (string, error) { - var credential TokenCredential + var tokenRetriever TokenRetriever - if provider.isManagedIdentityCredential() { - if !provider.cfg.Azure.ManagedIdentityEnabled { + switch c := credentials.(type) { + case *azcredentials.AzureManagedIdentityCredentials: + if !cfg.Azure.ManagedIdentityEnabled { err := fmt.Errorf("managed identity authentication is not enabled in Grafana config") - return "", err + return nil, err } else { - credential = provider.getManagedIdentityCredential() + tokenRetriever = getManagedIdentityTokenRetriever(cfg, c) } - } else { - credential = provider.getClientSecretCredential() + case *azcredentials.AzureClientSecretCredentials: + tokenRetriever = getClientSecretTokenRetriever(c) + default: + err := fmt.Errorf("credentials of type '%s' not supported by authentication provider", c.AzureAuthType()) + return nil, err } - accessToken, err := azureTokenCache.GetAccessToken(ctx, credential, provider.authParams.Scopes) - if err != nil { + tokenProvider := &tokenProviderImpl{ + tokenRetriever: tokenRetriever, + } + + return tokenProvider, nil +} + +func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context, scopes []string) (string, error) { + if ctx == nil { + err := fmt.Errorf("parameter 'ctx' cannot be nil") + return "", err + } + if scopes == nil { + err := fmt.Errorf("parameter 'scopes' cannot be nil") return "", err } + accessToken, err := azureTokenCache.GetAccessToken(ctx, provider.tokenRetriever, scopes) + if err != nil { + return "", err + } return accessToken, nil } -func (provider *tokenProviderImpl) isManagedIdentityCredential() bool { - authType := strings.ToLower(provider.authParams.Params["azure_auth_type"]) - clientId := provider.authParams.Params["client_id"] - - // Type of authentication being determined by the following logic: - // * If authType is set to 'msi' then user explicitly selected the managed identity authentication - // * If authType isn't set but other fields are configured then it's a datasource which was configured - // before managed identities where introduced, therefore use client secret authentication - // * If authType and other fields aren't set then it means the datasource never been configured - // and managed identity is the default authentication choice as long as managed identities are enabled - return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled) +func getManagedIdentityTokenRetriever(cfg *setting.Cfg, credentials *azcredentials.AzureManagedIdentityCredentials) TokenRetriever { + var clientId string + if credentials.ClientId != "" { + clientId = credentials.ClientId + } else { + clientId = cfg.Azure.ManagedIdentityClientId + } + return &managedIdentityTokenRetriever{ + clientId: clientId, + } } -func (provider *tokenProviderImpl) getManagedIdentityCredential() TokenCredential { - clientId := provider.cfg.Azure.ManagedIdentityClientId - - return &managedIdentityCredential{clientId: clientId} +func getClientSecretTokenRetriever(credentials *azcredentials.AzureClientSecretCredentials) TokenRetriever { + var authority string + if credentials.Authority != "" { + authority = credentials.Authority + } else { + authority = resolveAuthorityForCloud(credentials.AzureCloud) + } + return &clientSecretTokenRetriever{ + authority: authority, + tenantId: credentials.TenantId, + clientId: credentials.ClientId, + clientSecret: credentials.ClientSecret, + } } -func (provider *tokenProviderImpl) getClientSecretCredential() TokenCredential { - authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"]) - tenantId := provider.authParams.Params["tenant_id"] - clientId := provider.authParams.Params["client_id"] - clientSecret := provider.authParams.Params["client_secret"] - - return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret} -} - -func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string { +func resolveAuthorityForCloud(cloudName string) string { // Known Azure clouds switch cloudName { case setting.AzurePublic: @@ -93,17 +112,17 @@ func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string return azidentity.AzureGovernment case setting.AzureGermany: return azidentity.AzureGermany + default: + return "" } - // Fallback to direct URL - return provider.authParams.Url } -type managedIdentityCredential struct { +type managedIdentityTokenRetriever struct { clientId string credential azcore.TokenCredential } -func (c *managedIdentityCredential) GetCacheKey() string { +func (c *managedIdentityTokenRetriever) GetCacheKey() string { clientId := c.clientId if clientId == "" { clientId = "system" @@ -111,7 +130,7 @@ func (c *managedIdentityCredential) GetCacheKey() string { return fmt.Sprintf("azure|msi|%s", clientId) } -func (c *managedIdentityCredential) Init() error { +func (c *managedIdentityTokenRetriever) Init() error { if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil { return err } else { @@ -120,7 +139,7 @@ func (c *managedIdentityCredential) Init() error { } } -func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { +func (c *managedIdentityTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) if err != nil { return nil, err @@ -129,7 +148,7 @@ func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes [ return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil } -type clientSecretCredential struct { +type clientSecretTokenRetriever struct { authority string tenantId string clientId string @@ -137,11 +156,11 @@ type clientSecretCredential struct { credential azcore.TokenCredential } -func (c *clientSecretCredential) GetCacheKey() string { +func (c *clientSecretTokenRetriever) GetCacheKey() string { return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret)) } -func (c *clientSecretCredential) Init() error { +func (c *clientSecretTokenRetriever) Init() error { options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority} if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil { return err @@ -151,7 +170,7 @@ func (c *clientSecretCredential) Init() error { } } -func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { +func (c *clientSecretTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) if err != nil { return nil, err diff --git a/pkg/tsdb/azuremonitor/aztokenprovider/token_provider_test.go b/pkg/tsdb/azuremonitor/aztokenprovider/token_provider_test.go index 870d6c759fb..5b7ec3bcb29 100644 --- a/pkg/tsdb/azuremonitor/aztokenprovider/token_provider_test.go +++ b/pkg/tsdb/azuremonitor/aztokenprovider/token_provider_test.go @@ -4,125 +4,30 @@ import ( "context" "testing" - "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var getAccessTokenFunc func(credential TokenCredential, scopes []string) +var getAccessTokenFunc func(credential TokenRetriever, scopes []string) type tokenCacheFake struct{} -func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) { +func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenRetriever, scopes []string) (string, error) { getAccessTokenFunc(credential, scopes) return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil } -func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) { - cfg := &setting.Cfg{} - - authParams := &plugins.JwtTokenAuth{ - Scopes: []string{ - "https://management.azure.com/.default", - }, - Params: map[string]string{ - "azure_auth_type": "", - "azure_cloud": "AzureCloud", - "tenant_id": "", - "client_id": "", - "client_secret": "", - }, - } - - provider := NewAzureAccessTokenProvider(cfg, authParams) - - t.Run("when managed identities enabled", func(t *testing.T) { - cfg.Azure.ManagedIdentityEnabled = true - - t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "msi", - } - - assert.True(t, provider.isManagedIdentityCredential()) - }) - - t.Run("should be client secret if auth type is client secret", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "clientsecret", - } - - assert.False(t, provider.isManagedIdentityCredential()) - }) - - t.Run("should be managed identity if datasource not configured", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "", - "tenant_id": "", - "client_id": "", - "client_secret": "", - } - - assert.True(t, provider.isManagedIdentityCredential()) - }) - - t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "", - "tenant_id": "06da9207-bdd9-4558-aee4-377450893cb4", - "client_id": "b8c58fe8-1fca-4e30-a0a8-b44d0e5f70d6", - "client_secret": "9bcd4434-824f-4887-a8a8-94c287bf0a7b", - } - - assert.False(t, provider.isManagedIdentityCredential()) - }) - }) - - t.Run("when managed identities disabled", func(t *testing.T) { - cfg.Azure.ManagedIdentityEnabled = false - - t.Run("should be managed identity if auth type is managed identity", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "msi", - } - - assert.True(t, provider.isManagedIdentityCredential()) - }) - - t.Run("should be client secret if datasource not configured", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "", - "tenant_id": "", - "client_id": "", - "client_secret": "", - } - - assert.False(t, provider.isManagedIdentityCredential()) - }) - }) -} - -func TestAzureTokenProvider_getAccessToken(t *testing.T) { +func TestAzureTokenProvider_GetAccessToken(t *testing.T) { ctx := context.Background() cfg := &setting.Cfg{} - authParams := &plugins.JwtTokenAuth{ - Scopes: []string{ - "https://management.azure.com/.default", - }, - Params: map[string]string{ - "azure_auth_type": "", - "azure_cloud": "AzureCloud", - "tenant_id": "", - "client_id": "", - "client_secret": "", - }, + scopes := []string{ + "https://management.azure.com/.default", } - provider := NewAzureAccessTokenProvider(cfg, authParams) - original := azureTokenCache azureTokenCache = &tokenCacheFake{} t.Cleanup(func() { azureTokenCache = original }) @@ -130,29 +35,31 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) { t.Run("when managed identities enabled", func(t *testing.T) { cfg.Azure.ManagedIdentityEnabled = true - t.Run("should resolve managed identity credential if auth type is managed identity", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "msi", + t.Run("should resolve managed identity retriever if auth type is managed identity", func(t *testing.T) { + credentials := &azcredentials.AzureManagedIdentityCredentials{} + + provider, err := NewAzureAccessTokenProvider(cfg, credentials) + require.NoError(t, err) + + getAccessTokenFunc = func(credential TokenRetriever, scopes []string) { + assert.IsType(t, &managedIdentityTokenRetriever{}, credential) } - getAccessTokenFunc = func(credential TokenCredential, scopes []string) { - assert.IsType(t, &managedIdentityCredential{}, credential) - } - - _, err := provider.GetAccessToken(ctx) + _, err = provider.GetAccessToken(ctx, scopes) require.NoError(t, err) }) - t.Run("should resolve client secret credential if auth type is client secret", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "clientsecret", + t.Run("should resolve client secret retriever if auth type is client secret", func(t *testing.T) { + credentials := &azcredentials.AzureClientSecretCredentials{} + + provider, err := NewAzureAccessTokenProvider(cfg, credentials) + require.NoError(t, err) + + getAccessTokenFunc = func(credential TokenRetriever, scopes []string) { + assert.IsType(t, &clientSecretTokenRetriever{}, credential) } - getAccessTokenFunc = func(credential TokenCredential, scopes []string) { - assert.IsType(t, &clientSecretCredential{}, credential) - } - - _, err := provider.GetAccessToken(ctx) + _, err = provider.GetAccessToken(ctx, scopes) require.NoError(t, err) }) }) @@ -161,47 +68,61 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) { cfg.Azure.ManagedIdentityEnabled = false t.Run("should return error if auth type is managed identity", func(t *testing.T) { - authParams.Params = map[string]string{ - "azure_auth_type": "msi", - } + credentials := &azcredentials.AzureManagedIdentityCredentials{} - getAccessTokenFunc = func(credential TokenCredential, scopes []string) { - assert.Fail(t, "token cache not expected to be called") - } - - _, err := provider.GetAccessToken(ctx) - require.Error(t, err) + _, err := NewAzureAccessTokenProvider(cfg, credentials) + assert.Error(t, err, "managed identity authentication is not enabled in Grafana config") }) }) } func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) { - cfg := &setting.Cfg{} - - authParams := &plugins.JwtTokenAuth{ - Scopes: []string{ - "https://management.azure.com/.default", - }, - Params: map[string]string{ - "azure_auth_type": "", - "azure_cloud": "AzureCloud", - "tenant_id": "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", - "client_id": "1af7c188-e5b6-4f96-81b8-911761bdd459", - "client_secret": "0416d95e-8af8-472c-aaa3-15c93c46080a", - }, + credentials := &azcredentials.AzureClientSecretCredentials{ + AzureCloud: setting.AzurePublic, + Authority: "", + TenantId: "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", + ClientId: "1af7c188-e5b6-4f96-81b8-911761bdd459", + ClientSecret: "0416d95e-8af8-472c-aaa3-15c93c46080a", } - provider := NewAzureAccessTokenProvider(cfg, authParams) + t.Run("should return clientSecretTokenRetriever with values", func(t *testing.T) { + result := getClientSecretTokenRetriever(credentials) + assert.IsType(t, &clientSecretTokenRetriever{}, result) - t.Run("should return clientSecretCredential with values", func(t *testing.T) { - result := provider.getClientSecretCredential() - assert.IsType(t, &clientSecretCredential{}, result) - - credential := (result).(*clientSecretCredential) + credential := (result).(*clientSecretTokenRetriever) assert.Equal(t, "https://login.microsoftonline.com/", credential.authority) assert.Equal(t, "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", credential.tenantId) assert.Equal(t, "1af7c188-e5b6-4f96-81b8-911761bdd459", credential.clientId) assert.Equal(t, "0416d95e-8af8-472c-aaa3-15c93c46080a", credential.clientSecret) }) + + t.Run("authority should selected based on cloud", func(t *testing.T) { + originalCloud := credentials.AzureCloud + defer func() { credentials.AzureCloud = originalCloud }() + + credentials.AzureCloud = setting.AzureChina + + result := getClientSecretTokenRetriever(credentials) + assert.IsType(t, &clientSecretTokenRetriever{}, result) + + credential := (result).(*clientSecretTokenRetriever) + + assert.Equal(t, "https://login.chinacloudapi.cn/", credential.authority) + }) + + t.Run("explicitly set authority should have priority over cloud", func(t *testing.T) { + originalCloud := credentials.AzureCloud + defer func() { credentials.AzureCloud = originalCloud }() + + credentials.AzureCloud = setting.AzureChina + credentials.Authority = "https://another.com/" + + result := getClientSecretTokenRetriever(credentials) + assert.IsType(t, &clientSecretTokenRetriever{}, result) + + credential := (result).(*clientSecretTokenRetriever) + + assert.Equal(t, "https://another.com/", credential.authority) + }) } diff --git a/pkg/tsdb/azuremonitor/azuremonitor.go b/pkg/tsdb/azuremonitor/azuremonitor.go index 0370f97e070..bb1bf986751 100644 --- a/pkg/tsdb/azuremonitor/azuremonitor.go +++ b/pkg/tsdb/azuremonitor/azuremonitor.go @@ -11,12 +11,14 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend/datasource" "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" + "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/backendplugin" "github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin" "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" ) const ( @@ -44,20 +46,15 @@ type Service struct { } type azureMonitorSettings struct { + SubscriptionId string `json:"subscriptionId"` + LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"` AppInsightsAppId string `json:"appInsightsAppId"` AzureLogAnalyticsSameAs bool `json:"azureLogAnalyticsSameAs"` - ClientId string `json:"clientId"` - CloudName string `json:"cloudName"` - LogAnalyticsClientId string `json:"logAnalyticsClientId"` - LogAnalyticsDefaultWorkspace string `json:"logAnalyticsDefaultWorkspace"` - LogAnalyticsSubscriptionId string `json:"logAnalyticsSubscriptionId"` - LogAnalyticsTenantId string `json:"logAnalyticsTenantId"` - SubscriptionId string `json:"subscriptionId"` - TenantId string `json:"tenantId"` - AzureAuthType string `json:"azureAuthType,omitempty"` } type datasourceInfo struct { + Cloud string + Credentials azcredentials.AzureCredentials Settings azureMonitorSettings Services map[string]datasourceService Routes map[string]azRoute @@ -74,10 +71,15 @@ type datasourceService struct { HTTPClient *http.Client } -func NewInstanceSettings() datasource.InstanceFactoryFunc { +func NewInstanceSettings(cfg *setting.Cfg) datasource.InstanceFactoryFunc { return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { - jsonData := map[string]interface{}{} - err := json.Unmarshal(settings.JSONData, &jsonData) + jsonData, err := simplejson.NewJson(settings.JSONData) + if err != nil { + return nil, fmt.Errorf("error reading settings: %w", err) + } + + jsonDataObj := map[string]interface{}{} + err = json.Unmarshal(settings.JSONData, &jsonDataObj) if err != nil { return nil, fmt.Errorf("error reading settings: %w", err) } @@ -87,20 +89,34 @@ func NewInstanceSettings() datasource.InstanceFactoryFunc { if err != nil { return nil, fmt.Errorf("error reading settings: %w", err) } + + cloud, err := getAzureCloud(cfg, jsonData) + if err != nil { + return nil, fmt.Errorf("error getting credentials: %w", err) + } + + credentials, err := getAzureCredentials(cfg, jsonData, settings.DecryptedSecureJSONData) + if err != nil { + return nil, fmt.Errorf("error getting credentials: %w", err) + } + httpCliOpts, err := settings.HTTPClientOptions() if err != nil { return nil, fmt.Errorf("error getting http options: %w", err) } model := datasourceInfo{ + Cloud: cloud, + Credentials: credentials, Settings: azMonitorSettings, - JSONData: jsonData, + JSONData: jsonDataObj, DecryptedSecureJSONData: settings.DecryptedSecureJSONData, DatasourceID: settings.ID, Services: map[string]datasourceService{}, - Routes: routes[azMonitorSettings.CloudName], + Routes: routes[cloud], HTTPCliOpts: httpCliOpts, } + return model, nil } } @@ -141,7 +157,7 @@ func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors ma } func (s *Service) Init() error { - im := datasource.NewInstanceManager(NewInstanceSettings()) + im := datasource.NewInstanceManager(NewInstanceSettings(s.Cfg)) executors := map[string]azDatasourceExecutor{ azureMonitor: &AzureMonitorDatasource{}, appInsights: &ApplicationInsightsDatasource{}, diff --git a/pkg/tsdb/azuremonitor/azuremonitor_test.go b/pkg/tsdb/azuremonitor/azuremonitor_test.go index 7fb4ce86eb0..93470b1c191 100644 --- a/pkg/tsdb/azuremonitor/azuremonitor_test.go +++ b/pkg/tsdb/azuremonitor/azuremonitor_test.go @@ -9,6 +9,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" "github.com/stretchr/testify/require" ) @@ -22,14 +23,16 @@ func TestNewInstanceSettings(t *testing.T) { { name: "creates an instance", settings: backend.DataSourceInstanceSettings{ - JSONData: []byte(`{"cloudName":"azuremonitor"}`), + JSONData: []byte(`{"azureAuthType":"msi"}`), DecryptedSecureJSONData: map[string]string{"key": "value"}, ID: 40, }, expectedModel: datasourceInfo{ - Settings: azureMonitorSettings{CloudName: "azuremonitor"}, - Routes: routes["azuremonitor"], - JSONData: map[string]interface{}{"cloudName": string("azuremonitor")}, + Cloud: setting.AzurePublic, + Credentials: &azcredentials.AzureManagedIdentityCredentials{}, + Settings: azureMonitorSettings{}, + Routes: routes[setting.AzurePublic], + JSONData: map[string]interface{}{"azureAuthType": "msi"}, DatasourceID: 40, DecryptedSecureJSONData: map[string]string{"key": "value"}, }, @@ -37,9 +40,15 @@ func TestNewInstanceSettings(t *testing.T) { }, } + cfg := &setting.Cfg{ + Azure: setting.AzureSettings{ + Cloud: setting.AzurePublic, + }, + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - factory := NewInstanceSettings() + factory := NewInstanceSettings(cfg) instance, err := factory(tt.settings) tt.Err(t, err) if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) { diff --git a/pkg/tsdb/azuremonitor/credentials.go b/pkg/tsdb/azuremonitor/credentials.go index 69ec1f93abd..1e8b6e08abf 100644 --- a/pkg/tsdb/azuremonitor/credentials.go +++ b/pkg/tsdb/azuremonitor/credentials.go @@ -1,12 +1,11 @@ package azuremonitor import ( - "net/http" + "fmt" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/plugins" + "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" ) // Azure cloud names specific to Azure Monitor @@ -17,40 +16,107 @@ const ( azureMonitorGermany = "germanyazuremonitor" ) -// Azure cloud query types -const ( - azureMonitor = "Azure Monitor" - appInsights = "Application Insights" - azureLogAnalytics = "Azure Log Analytics" - insightsAnalytics = "Insights Analytics" - azureResourceGraph = "Azure Resource Graph" -) - -func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider { - if len(route.Scopes) > 0 { - tokenAuth := &plugins.JwtTokenAuth{ - Url: route.URL, - Scopes: route.Scopes, - Params: map[string]string{ - "azure_auth_type": model.Settings.AzureAuthType, - "azure_cloud": cfg.Azure.Cloud, - "tenant_id": model.Settings.TenantId, - "client_id": model.Settings.ClientId, - "client_secret": model.DecryptedSecureJSONData["clientSecret"], - }, - } - tokenProvider := aztokenprovider.NewAzureAccessTokenProvider(cfg, tokenAuth) - return httpclient.NewProvider(httpclient.ProviderOptions{ - Middlewares: []httpclient.Middleware{ - aztokenprovider.AuthMiddleware(tokenProvider), - }, - }) +func getAuthType(cfg *setting.Cfg, jsonData *simplejson.Json) string { + if azureAuthType := jsonData.Get("azureAuthType").MustString(); azureAuthType != "" { + return azureAuthType } else { - return httpclient.NewProvider() + tenantId := jsonData.Get("tenantId").MustString() + clientId := jsonData.Get("clientId").MustString() + + // If authentication type isn't explicitly specified and datasource has client credentials, + // then this is existing datasource which is configured for app registration (client secret) + if tenantId != "" && clientId != "" { + return azcredentials.AzureAuthClientSecret + } + + // For newly created datasource with no configuration, managed identity is the default authentication type + // if they are enabled in Grafana config + if cfg.Azure.ManagedIdentityEnabled { + return azcredentials.AzureAuthManagedIdentity + } else { + return azcredentials.AzureAuthClientSecret + } } } -func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) { - model.HTTPCliOpts.Headers = route.Headers - return httpClientProvider(route, model, cfg).New(model.HTTPCliOpts) +func getDefaultAzureCloud(cfg *setting.Cfg) (string, error) { + // Allow only known cloud names + cloudName := cfg.Azure.Cloud + switch cloudName { + case setting.AzurePublic: + return setting.AzurePublic, nil + case setting.AzureChina: + return setting.AzureChina, nil + case setting.AzureUSGovernment: + return setting.AzureUSGovernment, nil + case setting.AzureGermany: + return setting.AzureGermany, nil + case "": + // Not set cloud defaults to public + return setting.AzurePublic, nil + default: + err := fmt.Errorf("the cloud '%s' not supported", cloudName) + return "", err + } +} + +func normalizeAzureCloud(cloudName string) (string, error) { + switch cloudName { + case azureMonitorPublic: + return setting.AzurePublic, nil + case azureMonitorChina: + return setting.AzureChina, nil + case azureMonitorUSGovernment: + return setting.AzureUSGovernment, nil + case azureMonitorGermany: + return setting.AzureGermany, nil + default: + err := fmt.Errorf("the cloud '%s' not supported", cloudName) + return "", err + } +} + +func getAzureCloud(cfg *setting.Cfg, jsonData *simplejson.Json) (string, error) { + authType := getAuthType(cfg, jsonData) + switch authType { + case azcredentials.AzureAuthManagedIdentity: + // In case of managed identity, the cloud is always same as where Grafana is hosted + return getDefaultAzureCloud(cfg) + case azcredentials.AzureAuthClientSecret: + if cloud := jsonData.Get("cloudName").MustString(); cloud != "" { + return normalizeAzureCloud(cloud) + } else { + return getDefaultAzureCloud(cfg) + } + default: + err := fmt.Errorf("the authentication type '%s' not supported", authType) + return "", err + } +} + +func getAzureCredentials(cfg *setting.Cfg, jsonData *simplejson.Json, secureJsonData map[string]string) (azcredentials.AzureCredentials, error) { + authType := getAuthType(cfg, jsonData) + + switch authType { + case azcredentials.AzureAuthManagedIdentity: + credentials := &azcredentials.AzureManagedIdentityCredentials{} + return credentials, nil + + case azcredentials.AzureAuthClientSecret: + cloud, err := getAzureCloud(cfg, jsonData) + if err != nil { + return nil, err + } + credentials := &azcredentials.AzureClientSecretCredentials{ + AzureCloud: cloud, + TenantId: jsonData.Get("tenantId").MustString(), + ClientId: jsonData.Get("clientId").MustString(), + ClientSecret: secureJsonData["clientSecret"], + } + return credentials, nil + + default: + err := fmt.Errorf("the authentication type '%s' not supported", authType) + return nil, err + } } diff --git a/pkg/tsdb/azuremonitor/credentials_test.go b/pkg/tsdb/azuremonitor/credentials_test.go index 8cb555a88ef..4cefa5697c3 100644 --- a/pkg/tsdb/azuremonitor/credentials_test.go +++ b/pkg/tsdb/azuremonitor/credentials_test.go @@ -3,50 +3,195 @@ package azuremonitor import ( "testing" + "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_httpCliProvider(t *testing.T) { +func TestCredentials_getAuthType(t *testing.T) { cfg := &setting.Cfg{} - model := datasourceInfo{ - Settings: azureMonitorSettings{}, - DecryptedSecureJSONData: map[string]string{"clientSecret": "content"}, - } - tests := []struct { - name string - route azRoute - expectedMiddlewares int - Err require.ErrorAssertionFunc - }{ - { - name: "creates an HTTP client with a middleware", - route: azRoute{ - URL: "http://route", - Scopes: []string{"http://route/.default"}, - }, - expectedMiddlewares: 1, - Err: require.NoError, - }, - { - name: "creates an HTTP client without a middleware", - route: azRoute{ - URL: "http://route", - Scopes: []string{}, - }, - // httpclient.NewProvider returns a client with 2 middlewares by default - expectedMiddlewares: 2, - Err: require.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cli := httpClientProvider(tt.route, model, cfg) - // Cannot test that the cli middleware works properly since the azcore sdk - // rejects the TLS certs (if provided) - if len(cli.Opts.Middlewares) != tt.expectedMiddlewares { - t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares) - } + + t.Run("when managed identities enabled", func(t *testing.T) { + cfg.Azure.ManagedIdentityEnabled = true + + t.Run("should be client secret if auth type is set to client secret", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthClientSecret, + }) + + authType := getAuthType(cfg, jsonData) + + assert.Equal(t, azcredentials.AzureAuthClientSecret, authType) }) - } + + t.Run("should be managed identity if datasource not configured", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": "", + }) + + authType := getAuthType(cfg, jsonData) + + assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType) + }) + + t.Run("should be client secret if auth type not specified but credentials configured", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": "", + "tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", + "clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6", + }) + + authType := getAuthType(cfg, jsonData) + + assert.Equal(t, azcredentials.AzureAuthClientSecret, authType) + }) + }) + + t.Run("when managed identities disabled", func(t *testing.T) { + cfg.Azure.ManagedIdentityEnabled = false + + t.Run("should be managed identity if auth type is set to managed identity", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthManagedIdentity, + }) + + authType := getAuthType(cfg, jsonData) + + assert.Equal(t, azcredentials.AzureAuthManagedIdentity, authType) + }) + + t.Run("should be client secret if datasource not configured", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": "", + }) + + authType := getAuthType(cfg, jsonData) + + assert.Equal(t, azcredentials.AzureAuthClientSecret, authType) + }) + }) +} + +func TestCredentials_getAzureCloud(t *testing.T) { + cfg := &setting.Cfg{ + Azure: setting.AzureSettings{ + Cloud: setting.AzureChina, + }, + } + + t.Run("when auth type is managed identity", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthManagedIdentity, + "cloudName": azureMonitorGermany, + }) + + t.Run("should be from server configuration regardless of datasource value", func(t *testing.T) { + cloud, err := getAzureCloud(cfg, jsonData) + require.NoError(t, err) + + assert.Equal(t, setting.AzureChina, cloud) + }) + + t.Run("should be public if not set in server configuration", func(t *testing.T) { + cfg := &setting.Cfg{ + Azure: setting.AzureSettings{ + Cloud: "", + }, + } + + cloud, err := getAzureCloud(cfg, jsonData) + require.NoError(t, err) + + assert.Equal(t, setting.AzurePublic, cloud) + }) + }) + + t.Run("when auth type is client secret", func(t *testing.T) { + t.Run("should be from datasource value normalized to known cloud name", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthClientSecret, + "cloudName": azureMonitorGermany, + }) + + cloud, err := getAzureCloud(cfg, jsonData) + require.NoError(t, err) + + assert.Equal(t, setting.AzureGermany, cloud) + }) + + t.Run("should be from server configuration if not set in datasource", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthClientSecret, + "cloudName": "", + }) + + cloud, err := getAzureCloud(cfg, jsonData) + require.NoError(t, err) + + assert.Equal(t, setting.AzureChina, cloud) + }) + }) +} + +func TestCredentials_getAzureCredentials(t *testing.T) { + cfg := &setting.Cfg{ + Azure: setting.AzureSettings{ + Cloud: setting.AzureChina, + }, + } + + secureJsonData := map[string]string{ + "clientSecret": "59e3498f-eb12-4943-b8f0-a5aa42640058", + } + + t.Run("when auth type is managed identity", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthManagedIdentity, + "cloudName": azureMonitorGermany, + "tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", + "clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6", + }) + + t.Run("should return managed identity credentials", func(t *testing.T) { + credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData) + require.NoError(t, err) + require.IsType(t, &azcredentials.AzureManagedIdentityCredentials{}, credentials) + msiCredentials := credentials.(*azcredentials.AzureManagedIdentityCredentials) + + // Azure Monitor datasource doesn't support user-assigned managed identities (ClientId is always empty) + assert.Equal(t, "", msiCredentials.ClientId) + }) + }) + + t.Run("when auth type is client secret", func(t *testing.T) { + jsonData := simplejson.NewFromAny(map[string]interface{}{ + "azureAuthType": azcredentials.AzureAuthClientSecret, + "cloudName": azureMonitorGermany, + "tenantId": "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", + "clientId": "849ccbb0-92eb-4226-b228-ef391abd8fe6", + }) + + t.Run("should return client secret credentials", func(t *testing.T) { + cfg := &setting.Cfg{ + Azure: setting.AzureSettings{ + Cloud: setting.AzureChina, + }, + } + + credentials, err := getAzureCredentials(cfg, jsonData, secureJsonData) + require.NoError(t, err) + require.IsType(t, &azcredentials.AzureClientSecretCredentials{}, credentials) + clientSecretCredentials := credentials.(*azcredentials.AzureClientSecretCredentials) + + assert.Equal(t, setting.AzureGermany, clientSecretCredentials.AzureCloud) + assert.Equal(t, "9b9d90ee-a5cc-49c2-b97e-0d1b0f086b5c", clientSecretCredentials.TenantId) + assert.Equal(t, "849ccbb0-92eb-4226-b228-ef391abd8fe6", clientSecretCredentials.ClientId) + assert.Equal(t, "59e3498f-eb12-4943-b8f0-a5aa42640058", clientSecretCredentials.ClientSecret) + + // Azure Monitor datasource doesn't support custom IdP authorities (Authority is always empty) + assert.Equal(t, "", clientSecretCredentials.Authority) + }) + }) } diff --git a/pkg/tsdb/azuremonitor/httpclient.go b/pkg/tsdb/azuremonitor/httpclient.go new file mode 100644 index 00000000000..1d4035f20e1 --- /dev/null +++ b/pkg/tsdb/azuremonitor/httpclient.go @@ -0,0 +1,41 @@ +package azuremonitor + +import ( + "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider" +) + +func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*httpclient.Provider, error) { + var clientProvider *httpclient.Provider + + if len(route.Scopes) > 0 { + tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, model.Credentials) + if err != nil { + return nil, err + } + + clientProvider = httpclient.NewProvider(httpclient.ProviderOptions{ + Middlewares: []httpclient.Middleware{ + aztokenprovider.AuthMiddleware(tokenProvider, route.Scopes), + }, + }) + } else { + clientProvider = httpclient.NewProvider() + } + + return clientProvider, nil +} + +func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) { + model.HTTPCliOpts.Headers = route.Headers + + clientProvider, err := httpClientProvider(route, model, cfg) + if err != nil { + return nil, err + } + + return clientProvider.New(model.HTTPCliOpts) +} diff --git a/pkg/tsdb/azuremonitor/httpclient_test.go b/pkg/tsdb/azuremonitor/httpclient_test.go new file mode 100644 index 00000000000..4add39aa855 --- /dev/null +++ b/pkg/tsdb/azuremonitor/httpclient_test.go @@ -0,0 +1,54 @@ +package azuremonitor + +import ( + "testing" + + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials" + "github.com/stretchr/testify/require" +) + +func Test_httpCliProvider(t *testing.T) { + cfg := &setting.Cfg{} + model := datasourceInfo{ + Credentials: &azcredentials.AzureClientSecretCredentials{}, + } + tests := []struct { + name string + route azRoute + expectedMiddlewares int + Err require.ErrorAssertionFunc + }{ + { + name: "creates an HTTP client with a middleware", + route: azRoute{ + URL: "http://route", + Scopes: []string{"http://route/.default"}, + }, + expectedMiddlewares: 1, + Err: require.NoError, + }, + { + name: "creates an HTTP client without a middleware", + route: azRoute{ + URL: "http://route", + Scopes: []string{}, + }, + // httpclient.NewProvider returns a client with 2 middlewares by default + expectedMiddlewares: 2, + Err: require.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cli, err := httpClientProvider(tt.route, model, cfg) + require.NoError(t, err) + + // Cannot test that the cli middleware works properly since the azcore sdk + // rejects the TLS certs (if provided) + if len(cli.Opts.Middlewares) != tt.expectedMiddlewares { + t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares) + } + }) + } +} diff --git a/pkg/tsdb/azuremonitor/routes.go b/pkg/tsdb/azuremonitor/routes.go index 6c25c99c24f..9c834f94955 100644 --- a/pkg/tsdb/azuremonitor/routes.go +++ b/pkg/tsdb/azuremonitor/routes.go @@ -1,5 +1,16 @@ package azuremonitor +import "github.com/grafana/grafana/pkg/setting" + +// Azure cloud query types +const ( + azureMonitor = "Azure Monitor" + appInsights = "Application Insights" + azureLogAnalytics = "Azure Log Analytics" + insightsAnalytics = "Insights Analytics" + azureResourceGraph = "Azure Resource Graph" +) + type azRoute struct { URL string Scopes []string @@ -64,22 +75,22 @@ var ( // The different Azure routes are identified by its cloud (e.g. public or gov) // and the service to query (e.g. Azure Monitor or Azure Log Analytics) routes = map[string]map[string]azRoute{ - azureMonitorPublic: { + setting.AzurePublic: { azureMonitor: azManagement, azureLogAnalytics: azLogAnalytics, azureResourceGraph: azManagement, appInsights: azAppInsights, insightsAnalytics: azAppInsights, }, - azureMonitorUSGovernment: { + setting.AzureUSGovernment: { azureMonitor: azUSGovManagement, azureLogAnalytics: azUSGovLogAnalytics, azureResourceGraph: azUSGovManagement, }, - azureMonitorGermany: { + setting.AzureGermany: { azureMonitor: azGermanyManagement, }, - azureMonitorChina: { + setting.AzureChina: { azureMonitor: azChinaManagement, azureLogAnalytics: azChinaLogAnalytics, azureResourceGraph: azChinaManagement,