mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Plugins: Fix Azure token provider cache panic and auth param nil value (#34252)
* More tests for token cache * Safeguarding from panic and concurrency fixes * Update Azure dependencies * Fix interpolation of empty plugin data
This commit is contained in:
@@ -9,25 +9,20 @@ import (
|
||||
)
|
||||
|
||||
func TestApplyRoute_interpolateAuthParams(t *testing.T) {
|
||||
pluginRoute := &plugins.AppPluginRoute{
|
||||
Path: "pathwithjwttoken1",
|
||||
URL: "https://api.jwt.io/some/path",
|
||||
Method: "GET",
|
||||
TokenAuth: &plugins.JwtTokenAuth{
|
||||
Url: "https://login.server.com/{{.JsonData.tenantId}}/oauth2/token",
|
||||
Scopes: []string{
|
||||
"https://www.testapi.com/auth/Read.All",
|
||||
"https://www.testapi.com/auth/Write.All",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"token_uri": "{{.JsonData.tokenUri}}",
|
||||
"client_email": "{{.JsonData.clientEmail}}",
|
||||
"private_key": "{{.SecureJsonData.privateKey}}",
|
||||
},
|
||||
tokenAuth := &plugins.JwtTokenAuth{
|
||||
Url: "https://login.server.com/{{.JsonData.tenantId}}/oauth2/token",
|
||||
Scopes: []string{
|
||||
"https://www.testapi.com/auth/Read.All",
|
||||
"https://www.testapi.com/auth/Write.All",
|
||||
},
|
||||
Params: map[string]string{
|
||||
"token_uri": "{{.JsonData.tokenUri}}",
|
||||
"client_email": "{{.JsonData.clientEmail | orEmpty}}",
|
||||
"private_key": "{{.SecureJsonData.privateKey | orEmpty}}",
|
||||
},
|
||||
}
|
||||
|
||||
templateData := templateData{
|
||||
validData := templateData{
|
||||
JsonData: map[string]interface{}{
|
||||
"clientEmail": "test@test.com",
|
||||
"tokenUri": "login.url.com/token",
|
||||
@@ -38,8 +33,13 @@ func TestApplyRoute_interpolateAuthParams(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
emptyData := templateData{
|
||||
JsonData: map[string]interface{}{},
|
||||
SecureJsonData: map[string]string{},
|
||||
}
|
||||
|
||||
t.Run("should interpolate JwtTokenAuth struct using given JsonData", func(t *testing.T) {
|
||||
interpolated, err := interpolateAuthParams(pluginRoute.TokenAuth, templateData)
|
||||
interpolated, err := interpolateAuthParams(tokenAuth, validData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, interpolated)
|
||||
|
||||
@@ -55,8 +55,27 @@ func TestApplyRoute_interpolateAuthParams(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should return Nil if given JwtTokenAuth is Nil", func(t *testing.T) {
|
||||
interpolated, err := interpolateAuthParams(pluginRoute.JwtTokenAuth, templateData)
|
||||
interpolated, err := interpolateAuthParams(nil, validData)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, interpolated)
|
||||
})
|
||||
|
||||
t.Run("when plugin data is empty", func(t *testing.T) {
|
||||
interpolated, err := interpolateAuthParams(tokenAuth, emptyData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, interpolated)
|
||||
|
||||
t.Run("template expressions in url should resolve to <no value>", func(t *testing.T) {
|
||||
assert.Equal(t, "https://login.server.com/<no value>/oauth2/token", interpolated.Url)
|
||||
})
|
||||
|
||||
t.Run("template expressions in params resolve to <no value>", func(t *testing.T) {
|
||||
assert.Equal(t, "<no value>", interpolated.Params["token_uri"])
|
||||
})
|
||||
|
||||
t.Run("template expressions with orEmpty should resolve to empty string", func(t *testing.T) {
|
||||
assert.Equal(t, "", interpolated.Params["client_email"])
|
||||
assert.Equal(t, "", interpolated.Params["private_key"])
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,16 +97,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if shouldRefresh {
|
||||
accessToken, err = c.credential.GetAccessToken(ctx, c.scopes)
|
||||
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
c.accessToken = accessToken
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
|
||||
accessToken, err = c.refreshAccessToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -115,6 +106,31 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
return accessToken.Token, nil
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
|
||||
var accessToken *AccessToken
|
||||
|
||||
// Safeguarding from panic caused by credential implementation
|
||||
defer func() {
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
|
||||
if accessToken != nil {
|
||||
c.accessToken = accessToken
|
||||
}
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
}()
|
||||
|
||||
token, err := c.credential.GetAccessToken(ctx, c.scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken = token
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func getKeyForScopes(scopes []string) string {
|
||||
if len(scopes) > 1 {
|
||||
arr := make([]string, len(scopes))
|
||||
|
||||
@@ -2,7 +2,9 @@ package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -100,3 +102,185 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
assert.Equal(t, 1, credential2.calledTimes)
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "", accessToken)
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to get access token"))
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
@@ -44,7 +46,7 @@ func (provider *azureAccessTokenProvider) getAccessToken() (string, error) {
|
||||
|
||||
if provider.isManagedIdentityCredential() {
|
||||
if !provider.cfg.Azure.ManagedIdentityEnabled {
|
||||
err := fmt.Errorf("managed identity authentication not enabled in Grafana config")
|
||||
err := fmt.Errorf("managed identity authentication is not enabled in Grafana config")
|
||||
return "", err
|
||||
} else {
|
||||
credential = provider.getManagedIdentityCredential()
|
||||
@@ -106,8 +108,9 @@ func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string)
|
||||
}
|
||||
|
||||
type managedIdentityCredential struct {
|
||||
clientId string
|
||||
credential azcore.TokenCredential
|
||||
clientId string
|
||||
credLock sync.Mutex
|
||||
credValue atomic.Value // of azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
@@ -118,14 +121,29 @@ func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|msi|%s", clientId)
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
// No need to lock here because the caller is responsible for thread safety
|
||||
if c.credential == nil {
|
||||
func (c *managedIdentityCredential) getCredential() (azcore.TokenCredential, error) {
|
||||
credential := c.credValue.Load()
|
||||
|
||||
if credential == nil {
|
||||
c.credLock.Lock()
|
||||
defer c.credLock.Unlock()
|
||||
|
||||
var err error
|
||||
c.credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
|
||||
credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.credValue.Store(credential)
|
||||
}
|
||||
|
||||
return credential.(azcore.TokenCredential), nil
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
credential, err := c.getCredential()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource
|
||||
@@ -135,7 +153,7 @@ func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes [
|
||||
resource := strings.TrimSuffix(scopes[0], "/.default")
|
||||
scopes = []string{resource}
|
||||
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,24 +166,40 @@ type clientSecretCredential struct {
|
||||
tenantId string
|
||||
clientId string
|
||||
clientSecret string
|
||||
credential azcore.TokenCredential
|
||||
credLock sync.Mutex
|
||||
credValue atomic.Value // of azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
// No need to lock here because the caller is responsible for thread safety
|
||||
if c.credential == nil {
|
||||
func (c *clientSecretCredential) getCredential() (azcore.TokenCredential, error) {
|
||||
credential := c.credValue.Load()
|
||||
|
||||
if credential == nil {
|
||||
c.credLock.Lock()
|
||||
defer c.credLock.Unlock()
|
||||
|
||||
var err error
|
||||
c.credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
|
||||
credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.credValue.Store(credential)
|
||||
}
|
||||
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
return credential.(azcore.TokenCredential), nil
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
credential, err := c.getCredential()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,16 @@ import (
|
||||
|
||||
// interpolateString accepts template data and return a string with substitutions
|
||||
func interpolateString(text string, data templateData) (string, error) {
|
||||
t, err := template.New("content").Parse(text)
|
||||
extraFuncs := map[string]interface{}{
|
||||
"orEmpty": func(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return v
|
||||
},
|
||||
}
|
||||
|
||||
t, err := template.New("content").Funcs(extraFuncs).Parse(text)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not parse template %s", text)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user