grafana/pkg/tsdb/azuremonitor/aztokenprovider/token_cache_test.go

458 lines
12 KiB
Go

package aztokenprovider
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeRetriever is a test double for the token retriever handed to the cache.
// It records how often Init and GetAccessToken are invoked and lets a test
// replace either call through the optional initFunc / getAccessTokenFunc hooks.
// The counters are plain ints mutated without synchronization, so a single
// instance must not be shared across goroutines.
type fakeRetriever struct {
	// key is returned verbatim by GetCacheKey and prefixes the default tokens.
	key string

	// initCalledTimes counts Init calls since construction or the last Reset.
	initCalledTimes int
	// calledTimes counts GetAccessToken calls since construction or the last Reset.
	calledTimes int

	// initFunc, when non-nil, decides Init's result.
	initFunc func() error
	// getAccessTokenFunc, when non-nil, decides GetAccessToken's result.
	getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}

// GetCacheKey returns the key the fake was constructed with.
func (r *fakeRetriever) GetCacheKey() string {
	return r.key
}

// Reset zeroes both call counters; the key and hooks are left untouched.
func (r *fakeRetriever) Reset() {
	r.initCalledTimes, r.calledTimes = 0, 0
}

// Init counts the call and then delegates to initFunc, succeeding when no
// hook is installed.
func (r *fakeRetriever) Init() error {
	r.initCalledTimes++
	if r.initFunc == nil {
		return nil
	}
	return r.initFunc()
}

// GetAccessToken counts the call and delegates to getAccessTokenFunc when it
// is set. Otherwise it mints "<key>-token-<n>" (n being the running call
// count) that expires one hour after timeNow().
func (r *fakeRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
	r.calledTimes++
	if hook := r.getAccessTokenFunc; hook != nil {
		return hook(ctx, scopes)
	}
	return &AccessToken{
		Token:     fmt.Sprintf("%v-token-%v", r.key, r.calledTimes),
		ExpiresOn: timeNow().Add(time.Hour),
	}, nil
}
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes1 := []string{"Scope1"}
scopes2 := []string{"Scope2"}
t.Run("should request access token from retriever", func(t *testing.T) {
cache := NewConcurrentTokenCache()
tokenRetriever := &fakeRetriever{key: "retriever"}
token, err := cache.GetAccessToken(ctx, tokenRetriever, scopes1)
require.NoError(t, err)
assert.Equal(t, "retriever-token-1", token)
assert.Equal(t, 1, tokenRetriever.calledTimes)
})
t.Run("should return cached token for same scopes", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential := &fakeRetriever{key: "credential-1"}
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
assert.Equal(t, 2, credential.calledTimes)
})
t.Run("should return cached token for same credentials", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential1 := &fakeRetriever{key: "credential-1"}
credential2 := &fakeRetriever{key: "credential-2"}
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
assert.Equal(t, 1, credential1.calledTimes)
assert.Equal(t, 1, credential2.calledTimes)
})
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when retriever init returns error", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
initFunc: func() error {
return errors.New("unable to initialize")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
retriever: tokenRetriever,
}
err := cacheEntry.ensureInitialized()
assert.Error(t, err)
})
t.Run("should call init again each time and return error", func(t *testing.T) {
tokenRetriever.Reset()
cacheEntry := &credentialCacheEntry{
retriever: tokenRetriever,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
})
})
t.Run("when retriever init returns error only once", func(t *testing.T) {
var times = 0
tokenRetriever := &fakeRetriever{
initFunc: func() error {
times = times + 1
if times == 1 {
return errors.New("unable to initialize")
}
return nil
},
}
t.Run("should call retriever init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
retriever: tokenRetriever,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
})
})
t.Run("when retriever init panics", func(t *testing.T) {
tokenRetriever := &fakeRetriever{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call retriever init again each time", func(t *testing.T) {
tokenRetriever.Reset()
cacheEntry := &credentialCacheEntry{
retriever: tokenRetriever,
}
func() {
defer func() {
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
})
})
t.Run("when retriever init panics only once", func(t *testing.T) {
var times = 0
tokenRetriever := &fakeRetriever{
initFunc: func() error {
times = times + 1
if times == 1 {
panic(errors.New("unable to initialize"))
}
return nil
},
}
t.Run("should call retriever init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
retriever: tokenRetriever,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "retriever expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "retriever not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
func() {
defer func() {
assert.Nil(t, recover(), "retriever not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
})
})
}
// TestScopesCacheEntry_GetAccessToken checks that a scopes cache entry calls
// the retriever again after an error or a panic. It also checks that once a
// token has been obtained, later calls return that token without reaching the
// retriever.
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
	ctx := context.Background()
	scopes := []string{"Scope1"}

	// newEntry builds a fresh entry for scopes that is backed by retriever.
	newEntry := func(retriever *fakeRetriever) *scopesCacheEntry {
		return &scopesCacheEntry{
			retriever: retriever,
			scopes:    scopes,
			cond:      sync.NewCond(&sync.Mutex{}),
		}
	}

	// expectPanic runs fn and records a failure on t unless fn panics.
	expectPanic := func(t *testing.T, fn func()) {
		t.Helper()
		defer func() {
			assert.NotNil(t, recover(), "retriever expected to panic")
		}()
		fn()
	}

	// expectNoPanic runs fn and records a failure on t if fn panics.
	expectNoPanic := func(t *testing.T, fn func()) {
		t.Helper()
		defer func() {
			assert.Nil(t, recover(), "retriever not expected to panic")
		}()
		fn()
	}

	t.Run("when retriever getAccessToken returns error", func(t *testing.T) {
		tokenRetriever := &fakeRetriever{
			getAccessTokenFunc: func(_ context.Context, _ []string) (*AccessToken, error) {
				// A non-nil token accompanies the error; the entry should not surface it.
				invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
				return invalidToken, errors.New("unable to get access token")
			},
		}

		t.Run("should return error", func(t *testing.T) {
			accessToken, err := newEntry(tokenRetriever).getAccessToken(ctx)
			assert.Error(t, err)
			assert.Equal(t, "", accessToken)
		})

		t.Run("should call retriever again each time and return error", func(t *testing.T) {
			tokenRetriever.Reset()
			entry := newEntry(tokenRetriever)

			for attempt := 0; attempt < 3; attempt++ {
				_, err := entry.getAccessToken(ctx)
				assert.Error(t, err)
			}

			assert.Equal(t, 3, tokenRetriever.calledTimes)
		})
	})

	t.Run("when retriever getAccessToken returns error only once", func(t *testing.T) {
		attempts := 0
		retriever := &fakeRetriever{
			getAccessTokenFunc: func(_ context.Context, _ []string) (*AccessToken, error) {
				attempts++
				if attempts == 1 {
					invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
					return invalidToken, errors.New("unable to get access token")
				}
				return &AccessToken{Token: fmt.Sprintf("token-%v", attempts), ExpiresOn: timeNow().Add(time.Hour)}, nil
			},
		}

		t.Run("should call retriever again only while it returns error", func(t *testing.T) {
			entry := newEntry(retriever)

			_, firstErr := entry.getAccessToken(ctx)
			assert.Error(t, firstErr)

			// Both follow-up calls must yield the token from the second retriever call.
			for attempt := 0; attempt < 2; attempt++ {
				accessToken, err := entry.getAccessToken(ctx)
				assert.NoError(t, err)
				assert.Equal(t, "token-2", accessToken)
			}

			assert.Equal(t, 2, retriever.calledTimes)
		})
	})

	t.Run("when retriever getAccessToken panics", func(t *testing.T) {
		tokenRetriever := &fakeRetriever{
			getAccessTokenFunc: func(_ context.Context, _ []string) (*AccessToken, error) {
				panic(errors.New("unable to get access token"))
			},
		}

		t.Run("should call retriever again each time", func(t *testing.T) {
			tokenRetriever.Reset()
			entry := newEntry(tokenRetriever)

			for attempt := 0; attempt < 3; attempt++ {
				expectPanic(t, func() { _, _ = entry.getAccessToken(ctx) })
			}

			assert.Equal(t, 3, tokenRetriever.calledTimes)
		})
	})

	t.Run("when retriever getAccessToken panics only once", func(t *testing.T) {
		attempts := 0
		tokenRetriever := &fakeRetriever{
			getAccessTokenFunc: func(_ context.Context, _ []string) (*AccessToken, error) {
				attempts++
				if attempts == 1 {
					panic(errors.New("unable to get access token"))
				}
				return &AccessToken{Token: fmt.Sprintf("token-%v", attempts), ExpiresOn: timeNow().Add(time.Hour)}, nil
			},
		}

		t.Run("should call retriever again only while it panics", func(t *testing.T) {
			entry := newEntry(tokenRetriever)

			expectPanic(t, func() { _, _ = entry.getAccessToken(ctx) })
			for attempt := 0; attempt < 2; attempt++ {
				expectNoPanic(t, func() {
					accessToken, err := entry.getAccessToken(ctx)
					assert.NoError(t, err)
					assert.Equal(t, "token-2", accessToken)
				})
			}

			assert.Equal(t, 2, tokenRetriever.calledTimes)
		})
	})
}