mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AzureMonitor: token provider into aztokenprovider and cleanup (#36102)
This commit is contained in:
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
@@ -93,7 +92,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
|
||||
if tokenAuth == nil {
|
||||
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
|
||||
}
|
||||
provider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
provider := newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
return provider, nil
|
||||
|
||||
case "gce":
|
||||
|
||||
@@ -1,457 +0,0 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeCredential struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
initFunc func() error
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
return c.getAccessTokenFunc(ctx, scopes)
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
}
|
||||
|
||||
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes1 := []string{"Scope1"}
|
||||
scopes2 := []string{"Scope2"}
|
||||
|
||||
t.Run("should request access token from credential", func(t *testing.T) {
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token, err := cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token)
|
||||
|
||||
assert.Equal(t, 1, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeCredential{key: "credential-1"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same credentials", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential1 := &fakeCredential{key: "credential-1"}
|
||||
credential2 := &fakeCredential{key: "credential-2"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
assert.Equal(t, 1, credential1.calledTimes)
|
||||
assert.Equal(t, 1, credential2.calledTimes)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when credential init returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
return errors.New("unable to initialize")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to initialize"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "", accessToken)
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to get access token"))
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
25
pkg/api/pluginproxy/token_provider_azure.go
Normal file
25
pkg/api/pluginproxy/token_provider_azure.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
ctx context.Context
|
||||
tokenProvider aztokenprovider.AzureTokenProvider
|
||||
}
|
||||
|
||||
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
|
||||
return &azureAccessTokenProvider{
|
||||
ctx: ctx,
|
||||
tokenProvider: aztokenprovider.NewAzureAccessTokenProvider(cfg, authParams),
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
return provider.tokenProvider.GetAccessToken(provider.ctx)
|
||||
}
|
||||
@@ -1,30 +1,20 @@
|
||||
package tokenprovider
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
)
|
||||
|
||||
var (
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type TokenProvider interface {
|
||||
GetAccessToken() (string, error)
|
||||
}
|
||||
|
||||
const authenticationMiddlewareName = "AzureAuthentication"
|
||||
|
||||
func AuthMiddleware(tokenProvider TokenProvider) httpclient.Middleware {
|
||||
func AuthMiddleware(tokenProvider AzureTokenProvider) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken()
|
||||
token, err := tokenProvider.GetAccessToken(req.Context())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve azure access token: %w", err)
|
||||
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return next.RoundTrip(req)
|
||||
@@ -1,4 +1,4 @@
|
||||
package pluginproxy
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiresOn time.Time
|
||||
@@ -119,7 +124,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
|
||||
c.cond.L.Lock()
|
||||
for {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(timeNow().Add(2*time.Minute)) {
|
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken
|
||||
break
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenprovider
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenprovider
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -16,22 +16,23 @@ var (
|
||||
azureTokenCache = NewConcurrentTokenCache()
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
ctx context.Context
|
||||
type AzureTokenProvider interface {
|
||||
GetAccessToken(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
type tokenProviderImpl struct {
|
||||
cfg *setting.Cfg
|
||||
authParams *plugins.JwtTokenAuth
|
||||
}
|
||||
|
||||
func NewAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg,
|
||||
authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
|
||||
return &azureAccessTokenProvider{
|
||||
ctx: ctx,
|
||||
func NewAzureAccessTokenProvider(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *tokenProviderImpl {
|
||||
return &tokenProviderImpl{
|
||||
cfg: cfg,
|
||||
authParams: authParams,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context) (string, error) {
|
||||
var credential TokenCredential
|
||||
|
||||
if provider.isManagedIdentityCredential() {
|
||||
@@ -45,7 +46,7 @@ func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
credential = provider.getClientSecretCredential()
|
||||
}
|
||||
|
||||
accessToken, err := azureTokenCache.GetAccessToken(provider.ctx, credential, provider.authParams.Scopes)
|
||||
accessToken, err := azureTokenCache.GetAccessToken(ctx, credential, provider.authParams.Scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -53,7 +54,7 @@ func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) isManagedIdentityCredential() bool {
|
||||
func (provider *tokenProviderImpl) isManagedIdentityCredential() bool {
|
||||
authType := strings.ToLower(provider.authParams.Params["azure_auth_type"])
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
|
||||
@@ -66,13 +67,13 @@ func (provider *azureAccessTokenProvider) isManagedIdentityCredential() bool {
|
||||
return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled)
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getManagedIdentityCredential() TokenCredential {
|
||||
func (provider *tokenProviderImpl) getManagedIdentityCredential() TokenCredential {
|
||||
clientId := provider.cfg.Azure.ManagedIdentityClientId
|
||||
|
||||
return &managedIdentityCredential{clientId: clientId}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) getClientSecretCredential() TokenCredential {
|
||||
func (provider *tokenProviderImpl) getClientSecretCredential() TokenCredential {
|
||||
authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"])
|
||||
tenantId := provider.authParams.Params["tenant_id"]
|
||||
clientId := provider.authParams.Params["client_id"]
|
||||
@@ -81,7 +82,7 @@ func (provider *azureAccessTokenProvider) getClientSecretCredential() TokenCrede
|
||||
return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret}
|
||||
}
|
||||
|
||||
func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string) string {
|
||||
func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string {
|
||||
// Known Azure clouds
|
||||
switch cloudName {
|
||||
case setting.AzurePublic:
|
||||
@@ -1,4 +1,4 @@
|
||||
package tokenprovider
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,14 +14,12 @@ var getAccessTokenFunc func(credential TokenCredential, scopes []string)
|
||||
|
||||
type tokenCacheFake struct{}
|
||||
|
||||
func (c *tokenCacheFake) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
getAccessTokenFunc(credential, scopes)
|
||||
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
@@ -37,7 +35,7 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
cfg.Azure.ManagedIdentityEnabled = true
|
||||
@@ -123,7 +121,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
|
||||
original := azureTokenCache
|
||||
azureTokenCache = &tokenCacheFake{}
|
||||
@@ -141,7 +139,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.IsType(t, &managedIdentityCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken()
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -154,7 +152,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.IsType(t, &clientSecretCredential{}, credential)
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken()
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
@@ -171,15 +169,13 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
|
||||
assert.Fail(t, "token cache not expected to be called")
|
||||
}
|
||||
|
||||
_, err := provider.GetAccessToken()
|
||||
_, err := provider.GetAccessToken(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
|
||||
authParams := &plugins.JwtTokenAuth{
|
||||
@@ -195,7 +191,7 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
|
||||
provider := NewAzureAccessTokenProvider(cfg, authParams)
|
||||
|
||||
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
|
||||
result := provider.getClientSecretCredential()
|
||||
@@ -125,7 +125,7 @@ func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors ma
|
||||
if _, ok := dsInfo.Services[dst]; !ok {
|
||||
// Create an HTTP Client if it has not been created before
|
||||
route := dsInfo.Routes[dst]
|
||||
client, err := newHTTPClient(ctx, route, dsInfo, cfg)
|
||||
client, err := newHTTPClient(route, dsInfo, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
// Azure cloud names specific to Azure Monitor
|
||||
@@ -27,7 +26,7 @@ const (
|
||||
azureResourceGraph = "Azure Resource Graph"
|
||||
)
|
||||
|
||||
func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
|
||||
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
|
||||
if len(route.Scopes) > 0 {
|
||||
tokenAuth := &plugins.JwtTokenAuth{
|
||||
Url: route.URL,
|
||||
@@ -40,10 +39,10 @@ func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo
|
||||
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
|
||||
},
|
||||
}
|
||||
tokenProvider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
|
||||
tokenProvider := aztokenprovider.NewAzureAccessTokenProvider(cfg, tokenAuth)
|
||||
return httpclient.NewProvider(httpclient.ProviderOptions{
|
||||
Middlewares: []httpclient.Middleware{
|
||||
tokenprovider.AuthMiddleware(tokenProvider),
|
||||
aztokenprovider.AuthMiddleware(tokenProvider),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
@@ -51,7 +50,7 @@ func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPClient(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
|
||||
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
|
||||
model.HTTPCliOpts.Headers = route.Headers
|
||||
return httpClientProvider(ctx, route, model, cfg).New(model.HTTPCliOpts)
|
||||
return httpClientProvider(route, model, cfg).New(model.HTTPCliOpts)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func Test_httpCliProvider(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
cfg := &setting.Cfg{}
|
||||
model := datasourceInfo{
|
||||
Settings: azureMonitorSettings{},
|
||||
@@ -43,7 +41,7 @@ func Test_httpCliProvider(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := httpClientProvider(ctx, tt.route, model, cfg)
|
||||
cli := httpClientProvider(tt.route, model, cfg)
|
||||
// Cannot test that the cli middleware works properly since the azcore sdk
|
||||
// rejects the TLS certs (if provided)
|
||||
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
package tokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenCredential interface {
|
||||
GetCacheKey() string
|
||||
Init() error
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
type ConcurrentTokenCache interface {
|
||||
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache {
|
||||
return &tokenCacheImpl{}
|
||||
}
|
||||
|
||||
type tokenCacheImpl struct {
|
||||
cache sync.Map // of *credentialCacheEntry
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
credential TokenCredential
|
||||
|
||||
credInit uint32
|
||||
credMutex sync.Mutex
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
credential TokenCredential
|
||||
scopes []string
|
||||
|
||||
cond *sync.Cond
|
||||
refreshing bool
|
||||
accessToken *AccessToken
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := credential.GetCacheKey()
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
|
||||
credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*credentialCacheEntry)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
err := c.ensureInitialized()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.getEntryFor(scopes).getAccessToken(ctx)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) ensureInitialized() error {
|
||||
if atomic.LoadUint32(&c.credInit) == 0 {
|
||||
c.credMutex.Lock()
|
||||
defer c.credMutex.Unlock()
|
||||
|
||||
if c.credInit == 0 {
|
||||
// Initialize credential
|
||||
err := c.credential.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.credInit, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := getKeyForScopes(scopes)
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
|
||||
credential: c.credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*scopesCacheEntry)
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
var accessToken *AccessToken
|
||||
var err error
|
||||
shouldRefresh := false
|
||||
|
||||
c.cond.L.Lock()
|
||||
for {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
|
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken
|
||||
break
|
||||
}
|
||||
|
||||
if !c.refreshing {
|
||||
// Start refreshing the token
|
||||
c.refreshing = true
|
||||
shouldRefresh = true
|
||||
break
|
||||
}
|
||||
|
||||
// Wait for the token to be refreshed
|
||||
c.cond.Wait()
|
||||
}
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if shouldRefresh {
|
||||
accessToken, err = c.refreshAccessToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken.Token, nil
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
|
||||
var accessToken *AccessToken
|
||||
|
||||
// Safeguarding from panic caused by credential implementation
|
||||
defer func() {
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
|
||||
if accessToken != nil {
|
||||
c.accessToken = accessToken
|
||||
}
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
}()
|
||||
|
||||
token, err := c.credential.GetAccessToken(ctx, c.scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken = token
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func getKeyForScopes(scopes []string) string {
|
||||
if len(scopes) > 1 {
|
||||
arr := make([]string, len(scopes))
|
||||
copy(arr, scopes)
|
||||
sort.Strings(arr)
|
||||
scopes = arr
|
||||
}
|
||||
|
||||
return strings.Join(scopes, " ")
|
||||
}
|
||||
Reference in New Issue
Block a user