Plugins: Fix Azure token provider cache panic and auth param nil value (#34252)

* More tests for token cache

* Safeguarding from panic and concurrency fixes

* Update Azure dependencies

* Fix interpolation of empty plugin data
This commit is contained in:
Sergey Kostrukov
2021-05-18 06:36:58 -07:00
committed by GitHub
parent 63b2dd06a5
commit c1b8a10f41
8 changed files with 363 additions and 86 deletions

View File

@@ -9,25 +9,20 @@ import (
)
func TestApplyRoute_interpolateAuthParams(t *testing.T) {
pluginRoute := &plugins.AppPluginRoute{
Path: "pathwithjwttoken1",
URL: "https://api.jwt.io/some/path",
Method: "GET",
TokenAuth: &plugins.JwtTokenAuth{
Url: "https://login.server.com/{{.JsonData.tenantId}}/oauth2/token",
Scopes: []string{
"https://www.testapi.com/auth/Read.All",
"https://www.testapi.com/auth/Write.All",
},
Params: map[string]string{
"token_uri": "{{.JsonData.tokenUri}}",
"client_email": "{{.JsonData.clientEmail}}",
"private_key": "{{.SecureJsonData.privateKey}}",
},
tokenAuth := &plugins.JwtTokenAuth{
Url: "https://login.server.com/{{.JsonData.tenantId}}/oauth2/token",
Scopes: []string{
"https://www.testapi.com/auth/Read.All",
"https://www.testapi.com/auth/Write.All",
},
Params: map[string]string{
"token_uri": "{{.JsonData.tokenUri}}",
"client_email": "{{.JsonData.clientEmail | orEmpty}}",
"private_key": "{{.SecureJsonData.privateKey | orEmpty}}",
},
}
templateData := templateData{
validData := templateData{
JsonData: map[string]interface{}{
"clientEmail": "test@test.com",
"tokenUri": "login.url.com/token",
@@ -38,8 +33,13 @@ func TestApplyRoute_interpolateAuthParams(t *testing.T) {
},
}
emptyData := templateData{
JsonData: map[string]interface{}{},
SecureJsonData: map[string]string{},
}
t.Run("should interpolate JwtTokenAuth struct using given JsonData", func(t *testing.T) {
interpolated, err := interpolateAuthParams(pluginRoute.TokenAuth, templateData)
interpolated, err := interpolateAuthParams(tokenAuth, validData)
require.NoError(t, err)
require.NotNil(t, interpolated)
@@ -55,8 +55,27 @@ func TestApplyRoute_interpolateAuthParams(t *testing.T) {
})
t.Run("should return Nil if given JwtTokenAuth is Nil", func(t *testing.T) {
interpolated, err := interpolateAuthParams(pluginRoute.JwtTokenAuth, templateData)
interpolated, err := interpolateAuthParams(nil, validData)
require.NoError(t, err)
require.Nil(t, interpolated)
})
t.Run("when plugin data is empty", func(t *testing.T) {
interpolated, err := interpolateAuthParams(tokenAuth, emptyData)
require.NoError(t, err)
require.NotNil(t, interpolated)
t.Run("template expressions in url should resolve to <no value>", func(t *testing.T) {
assert.Equal(t, "https://login.server.com/<no value>/oauth2/token", interpolated.Url)
})
t.Run("template expressions in params resolve to <no value>", func(t *testing.T) {
assert.Equal(t, "<no value>", interpolated.Params["token_uri"])
})
t.Run("template expressions with orEmpty should resolve to empty string", func(t *testing.T) {
assert.Equal(t, "", interpolated.Params["client_email"])
assert.Equal(t, "", interpolated.Params["private_key"])
})
})
}

View File

@@ -97,16 +97,7 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
c.cond.L.Unlock()
if shouldRefresh {
accessToken, err = c.credential.GetAccessToken(ctx, c.scopes)
c.cond.L.Lock()
c.refreshing = false
c.accessToken = accessToken
c.cond.Broadcast()
c.cond.L.Unlock()
accessToken, err = c.refreshAccessToken(ctx)
if err != nil {
return "", err
}
@@ -115,6 +106,31 @@ func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
return accessToken.Token, nil
}
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
var accessToken *AccessToken
// Safeguarding from panic caused by credential implementation
defer func() {
c.cond.L.Lock()
c.refreshing = false
if accessToken != nil {
c.accessToken = accessToken
}
c.cond.Broadcast()
c.cond.L.Unlock()
}()
token, err := c.credential.GetAccessToken(ctx, c.scopes)
if err != nil {
return nil, err
}
accessToken = token
return accessToken, nil
}
func getKeyForScopes(scopes []string) string {
if len(scopes) > 1 {
arr := make([]string, len(scopes))

View File

@@ -2,7 +2,9 @@ package pluginproxy
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
@@ -100,3 +102,185 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
assert.Equal(t, 1, credential2.calledTimes)
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
accessToken, err := cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, "", accessToken)
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.calledTimes = 0
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it returns error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
assert.Equal(t, 2, credential.calledTimes)
})
})
t.Run("when credential panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
},
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.calledTimes = 0
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
panic(errors.New("unable to get access token"))
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it panics", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
assert.Equal(t, 2, credential.calledTimes)
})
})
}

View File

@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -44,7 +46,7 @@ func (provider *azureAccessTokenProvider) getAccessToken() (string, error) {
if provider.isManagedIdentityCredential() {
if !provider.cfg.Azure.ManagedIdentityEnabled {
err := fmt.Errorf("managed identity authentication not enabled in Grafana config")
err := fmt.Errorf("managed identity authentication is not enabled in Grafana config")
return "", err
} else {
credential = provider.getManagedIdentityCredential()
@@ -106,8 +108,9 @@ func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string)
}
type managedIdentityCredential struct {
clientId string
credential azcore.TokenCredential
clientId string
credLock sync.Mutex
credValue atomic.Value // of azcore.TokenCredential
}
func (c *managedIdentityCredential) GetCacheKey() string {
@@ -118,14 +121,29 @@ func (c *managedIdentityCredential) GetCacheKey() string {
return fmt.Sprintf("azure|msi|%s", clientId)
}
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
// No need to lock here because the caller is responsible for thread safety
if c.credential == nil {
func (c *managedIdentityCredential) getCredential() (azcore.TokenCredential, error) {
credential := c.credValue.Load()
if credential == nil {
c.credLock.Lock()
defer c.credLock.Unlock()
var err error
c.credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
if err != nil {
return nil, err
}
c.credValue.Store(credential)
}
return credential.(azcore.TokenCredential), nil
}
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
credential, err := c.getCredential()
if err != nil {
return nil, err
}
// Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource
@@ -135,7 +153,7 @@ func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes [
resource := strings.TrimSuffix(scopes[0], "/.default")
scopes = []string{resource}
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err
}
@@ -148,24 +166,40 @@ type clientSecretCredential struct {
tenantId string
clientId string
clientSecret string
credential azcore.TokenCredential
credLock sync.Mutex
credValue atomic.Value // of azcore.TokenCredential
}
func (c *clientSecretCredential) GetCacheKey() string {
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
}
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
// No need to lock here because the caller is responsible for thread safety
if c.credential == nil {
func (c *clientSecretCredential) getCredential() (azcore.TokenCredential, error) {
credential := c.credValue.Load()
if credential == nil {
c.credLock.Lock()
defer c.credLock.Unlock()
var err error
c.credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
if err != nil {
return nil, err
}
c.credValue.Store(credential)
}
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
return credential.(azcore.TokenCredential), nil
}
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
credential, err := c.getCredential()
if err != nil {
return nil, err
}
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,16 @@ import (
// interpolateString accepts template data and return a string with substitutions
func interpolateString(text string, data templateData) (string, error) {
t, err := template.New("content").Parse(text)
extraFuncs := map[string]interface{}{
"orEmpty": func(v interface{}) interface{} {
if v == nil {
return ""
}
return v
},
}
t, err := template.New("content").Funcs(extraFuncs).Parse(text)
if err != nil {
return "", fmt.Errorf("could not parse template %s", text)
}