AzureMonitor: token provider into aztokenprovider and cleanup (#36102)

This commit is contained in:
Sergey Kostrukov
2021-06-29 01:05:42 -07:00
committed by GitHub
parent 93cd375ada
commit 52e38c54e5
12 changed files with 68 additions and 696 deletions

View File

@@ -10,7 +10,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
"github.com/grafana/grafana/pkg/util"
)
@@ -93,7 +92,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
if tokenAuth == nil {
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
}
provider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
provider := newAzureAccessTokenProvider(ctx, cfg, tokenAuth)
return provider, nil
case "gce":

View File

@@ -1,457 +0,0 @@
package pluginproxy
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeCredential struct {
key string
initCalledTimes int
calledTimes int
initFunc func() error
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
func (c *fakeCredential) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) Reset() {
c.initCalledTimes = 0
c.calledTimes = 0
}
func (c *fakeCredential) Init() error {
c.initCalledTimes = c.initCalledTimes + 1
if c.initFunc != nil {
return c.initFunc()
}
return nil
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
return c.getAccessTokenFunc(ctx, scopes)
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
}
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes1 := []string{"Scope1"}
scopes2 := []string{"Scope2"}
t.Run("should request access token from credential", func(t *testing.T) {
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token, err := cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token)
assert.Equal(t, 1, credential.calledTimes)
})
t.Run("should return cached token for same scopes", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
assert.Equal(t, 2, credential.calledTimes)
})
t.Run("should return cached token for same credentials", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential1 := &fakeCredential{key: "credential-1"}
credential2 := &fakeCredential{key: "credential-2"}
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
assert.Equal(t, 1, credential1.calledTimes)
assert.Equal(t, 1, credential2.calledTimes)
})
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when credential init returns error", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
return errors.New("unable to initialize")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
err := cacheEntry.ensureInitialized()
assert.Error(t, err)
})
t.Run("should call init again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
return errors.New("unable to initialize")
}
return nil
},
}
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, credential.initCalledTimes)
})
})
t.Run("when credential init panics", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call credential init again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
panic(errors.New("unable to initialize"))
}
return nil
},
}
t.Run("should call credential init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, credential.initCalledTimes)
})
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
accessToken, err := cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, "", accessToken)
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it returns error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
assert.Equal(t, 2, credential.calledTimes)
})
})
t.Run("when credential getAccessToken panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
},
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
panic(errors.New("unable to get access token"))
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it panics", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
assert.Equal(t, 2, credential.calledTimes)
})
})
}

View File

@@ -0,0 +1,25 @@
package pluginproxy
import (
"context"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
)
type azureAccessTokenProvider struct {
ctx context.Context
tokenProvider aztokenprovider.AzureTokenProvider
}
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
return &azureAccessTokenProvider{
ctx: ctx,
tokenProvider: aztokenprovider.NewAzureAccessTokenProvider(cfg, authParams),
}
}
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
return provider.tokenProvider.GetAccessToken(provider.ctx)
}

View File

@@ -1,30 +1,20 @@
package tokenprovider
package aztokenprovider
import (
"fmt"
"net/http"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
var (
// timeNow makes it possible to test usage of time
timeNow = time.Now
)
type TokenProvider interface {
GetAccessToken() (string, error)
}
const authenticationMiddlewareName = "AzureAuthentication"
func AuthMiddleware(tokenProvider TokenProvider) httpclient.Middleware {
func AuthMiddleware(tokenProvider AzureTokenProvider) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenProvider.GetAccessToken()
token, err := tokenProvider.GetAccessToken(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to retrieve azure access token: %w", err)
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return next.RoundTrip(req)

View File

@@ -1,4 +1,4 @@
package pluginproxy
package aztokenprovider
import (
"context"
@@ -9,6 +9,11 @@ import (
"time"
)
var (
// timeNow makes it possible to test usage of time
timeNow = time.Now
)
type AccessToken struct {
Token string
ExpiresOn time.Time
@@ -119,7 +124,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
c.cond.L.Lock()
for {
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
if c.accessToken != nil && c.accessToken.ExpiresOn.After(timeNow().Add(2*time.Minute)) {
// Use the cached token since it's available and not expired yet
accessToken = c.accessToken
break

View File

@@ -1,4 +1,4 @@
package tokenprovider
package aztokenprovider
import (
"context"

View File

@@ -1,4 +1,4 @@
package tokenprovider
package aztokenprovider
import (
"context"
@@ -16,22 +16,23 @@ var (
azureTokenCache = NewConcurrentTokenCache()
)
type azureAccessTokenProvider struct {
ctx context.Context
type AzureTokenProvider interface {
GetAccessToken(ctx context.Context) (string, error)
}
type tokenProviderImpl struct {
cfg *setting.Cfg
authParams *plugins.JwtTokenAuth
}
func NewAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg,
authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
return &azureAccessTokenProvider{
ctx: ctx,
func NewAzureAccessTokenProvider(cfg *setting.Cfg, authParams *plugins.JwtTokenAuth) *tokenProviderImpl {
return &tokenProviderImpl{
cfg: cfg,
authParams: authParams,
}
}
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context) (string, error) {
var credential TokenCredential
if provider.isManagedIdentityCredential() {
@@ -45,7 +46,7 @@ func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
credential = provider.getClientSecretCredential()
}
accessToken, err := azureTokenCache.GetAccessToken(provider.ctx, credential, provider.authParams.Scopes)
accessToken, err := azureTokenCache.GetAccessToken(ctx, credential, provider.authParams.Scopes)
if err != nil {
return "", err
}
@@ -53,7 +54,7 @@ func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
return accessToken, nil
}
func (provider *azureAccessTokenProvider) isManagedIdentityCredential() bool {
func (provider *tokenProviderImpl) isManagedIdentityCredential() bool {
authType := strings.ToLower(provider.authParams.Params["azure_auth_type"])
clientId := provider.authParams.Params["client_id"]
@@ -66,13 +67,13 @@ func (provider *azureAccessTokenProvider) isManagedIdentityCredential() bool {
return authType == "msi" || (authType == "" && clientId == "" && provider.cfg.Azure.ManagedIdentityEnabled)
}
func (provider *azureAccessTokenProvider) getManagedIdentityCredential() TokenCredential {
func (provider *tokenProviderImpl) getManagedIdentityCredential() TokenCredential {
clientId := provider.cfg.Azure.ManagedIdentityClientId
return &managedIdentityCredential{clientId: clientId}
}
func (provider *azureAccessTokenProvider) getClientSecretCredential() TokenCredential {
func (provider *tokenProviderImpl) getClientSecretCredential() TokenCredential {
authority := provider.resolveAuthorityHost(provider.authParams.Params["azure_cloud"])
tenantId := provider.authParams.Params["tenant_id"]
clientId := provider.authParams.Params["client_id"]
@@ -81,7 +82,7 @@ func (provider *azureAccessTokenProvider) getClientSecretCredential() TokenCrede
return &clientSecretCredential{authority: authority, tenantId: tenantId, clientId: clientId, clientSecret: clientSecret}
}
func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string) string {
func (provider *tokenProviderImpl) resolveAuthorityHost(cloudName string) string {
// Known Azure clouds
switch cloudName {
case setting.AzurePublic:

View File

@@ -1,4 +1,4 @@
package tokenprovider
package aztokenprovider
import (
"context"
@@ -14,14 +14,12 @@ var getAccessTokenFunc func(credential TokenCredential, scopes []string)
type tokenCacheFake struct{}
func (c *tokenCacheFake) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
getAccessTokenFunc(credential, scopes)
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
}
func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
ctx := context.Background()
cfg := &setting.Cfg{}
authParams := &plugins.JwtTokenAuth{
@@ -37,7 +35,7 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
},
}
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
provider := NewAzureAccessTokenProvider(cfg, authParams)
t.Run("when managed identities enabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = true
@@ -123,7 +121,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
},
}
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
provider := NewAzureAccessTokenProvider(cfg, authParams)
original := azureTokenCache
azureTokenCache = &tokenCacheFake{}
@@ -141,7 +139,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.IsType(t, &managedIdentityCredential{}, credential)
}
_, err := provider.GetAccessToken()
_, err := provider.GetAccessToken(ctx)
require.NoError(t, err)
})
@@ -154,7 +152,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.IsType(t, &clientSecretCredential{}, credential)
}
_, err := provider.GetAccessToken()
_, err := provider.GetAccessToken(ctx)
require.NoError(t, err)
})
})
@@ -171,15 +169,13 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.Fail(t, "token cache not expected to be called")
}
_, err := provider.GetAccessToken()
_, err := provider.GetAccessToken(ctx)
require.Error(t, err)
})
})
}
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
ctx := context.Background()
cfg := &setting.Cfg{}
authParams := &plugins.JwtTokenAuth{
@@ -195,7 +191,7 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
},
}
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
provider := NewAzureAccessTokenProvider(cfg, authParams)
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
result := provider.getClientSecretCredential()

View File

@@ -125,7 +125,7 @@ func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors ma
if _, ok := dsInfo.Services[dst]; !ok {
// Create an HTTP Client if it has not been created before
route := dsInfo.Routes[dst]
client, err := newHTTPClient(ctx, route, dsInfo, cfg)
client, err := newHTTPClient(route, dsInfo, cfg)
if err != nil {
return nil, err
}

View File

@@ -1,13 +1,12 @@
package azuremonitor
import (
"context"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
)
// Azure cloud names specific to Azure Monitor
@@ -27,7 +26,7 @@ const (
azureResourceGraph = "Azure Resource Graph"
)
func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
func httpClientProvider(route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
if len(route.Scopes) > 0 {
tokenAuth := &plugins.JwtTokenAuth{
Url: route.URL,
@@ -40,10 +39,10 @@ func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
},
}
tokenProvider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
tokenProvider := aztokenprovider.NewAzureAccessTokenProvider(cfg, tokenAuth)
return httpclient.NewProvider(httpclient.ProviderOptions{
Middlewares: []httpclient.Middleware{
tokenprovider.AuthMiddleware(tokenProvider),
aztokenprovider.AuthMiddleware(tokenProvider),
},
})
} else {
@@ -51,7 +50,7 @@ func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo
}
}
func newHTTPClient(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
func newHTTPClient(route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
model.HTTPCliOpts.Headers = route.Headers
return httpClientProvider(ctx, route, model, cfg).New(model.HTTPCliOpts)
return httpClientProvider(route, model, cfg).New(model.HTTPCliOpts)
}

View File

@@ -1,7 +1,6 @@
package azuremonitor
import (
"context"
"testing"
"github.com/grafana/grafana/pkg/setting"
@@ -9,7 +8,6 @@ import (
)
func Test_httpCliProvider(t *testing.T) {
ctx := context.TODO()
cfg := &setting.Cfg{}
model := datasourceInfo{
Settings: azureMonitorSettings{},
@@ -43,7 +41,7 @@ func Test_httpCliProvider(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli := httpClientProvider(ctx, tt.route, model, cfg)
cli := httpClientProvider(tt.route, model, cfg)
// Cannot test that the cli middleware works properly since the azcore sdk
// rejects the TLS certs (if provided)
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {

View File

@@ -1,184 +0,0 @@
package tokenprovider
import (
"context"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
type AccessToken struct {
Token string
ExpiresOn time.Time
}
type TokenCredential interface {
GetCacheKey() string
Init() error
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
}
type ConcurrentTokenCache interface {
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
}
func NewConcurrentTokenCache() ConcurrentTokenCache {
return &tokenCacheImpl{}
}
type tokenCacheImpl struct {
cache sync.Map // of *credentialCacheEntry
}
type credentialCacheEntry struct {
credential TokenCredential
credInit uint32
credMutex sync.Mutex
cache sync.Map // of *scopesCacheEntry
}
type scopesCacheEntry struct {
credential TokenCredential
scopes []string
cond *sync.Cond
refreshing bool
accessToken *AccessToken
}
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
}
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
var entry interface{}
var ok bool
key := credential.GetCacheKey()
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
credential: credential,
})
}
return entry.(*credentialCacheEntry)
}
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
err := c.ensureInitialized()
if err != nil {
return "", err
}
return c.getEntryFor(scopes).getAccessToken(ctx)
}
func (c *credentialCacheEntry) ensureInitialized() error {
if atomic.LoadUint32(&c.credInit) == 0 {
c.credMutex.Lock()
defer c.credMutex.Unlock()
if c.credInit == 0 {
// Initialize credential
err := c.credential.Init()
if err != nil {
return err
}
atomic.StoreUint32(&c.credInit, 1)
}
}
return nil
}
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
var entry interface{}
var ok bool
key := getKeyForScopes(scopes)
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
credential: c.credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
})
}
return entry.(*scopesCacheEntry)
}
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
var accessToken *AccessToken
var err error
shouldRefresh := false
c.cond.L.Lock()
for {
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
// Use the cached token since it's available and not expired yet
accessToken = c.accessToken
break
}
if !c.refreshing {
// Start refreshing the token
c.refreshing = true
shouldRefresh = true
break
}
// Wait for the token to be refreshed
c.cond.Wait()
}
c.cond.L.Unlock()
if shouldRefresh {
accessToken, err = c.refreshAccessToken(ctx)
if err != nil {
return "", err
}
}
return accessToken.Token, nil
}
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
var accessToken *AccessToken
// Safeguarding from panic caused by credential implementation
defer func() {
c.cond.L.Lock()
c.refreshing = false
if accessToken != nil {
c.accessToken = accessToken
}
c.cond.Broadcast()
c.cond.L.Unlock()
}()
token, err := c.credential.GetAccessToken(ctx, c.scopes)
if err != nil {
return nil, err
}
accessToken = token
return accessToken, nil
}
func getKeyForScopes(scopes []string) string {
if len(scopes) > 1 {
arr := make([]string, len(scopes))
copy(arr, scopes)
sort.Strings(arr)
scopes = arr
}
return strings.Join(scopes, " ")
}