AzureMonitor: Fix Azure token provider national clouds (#34615)

* Fix AAD authority for sovereign clouds

* Update Azure SDK with scopes fix

* Credential initialization in cache
This commit is contained in:
Sergey Kostrukov 2021-05-24 23:19:08 -07:00 committed by GitHub
parent 9b518669dd
commit a337f70469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 260 additions and 89 deletions

8
go.mod
View File

@ -14,8 +14,8 @@ replace k8s.io/client-go => k8s.io/client-go v0.18.8
require (
cloud.google.com/go/storage v1.14.0
cuelang.org/go v0.3.2
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1
github.com/BurntSushi/toml v0.3.1
github.com/Masterminds/semver v1.5.0
github.com/VividCortex/mysqlerr v0.0.0-20170204212430-6c6b55f8796f
@ -94,10 +94,10 @@ require (
go.opentelemetry.io/collector v0.25.0
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
golang.org/x/exp v0.0.0-20210220032938-85be41e4509f // indirect
golang.org/x/net v0.0.0-20210510120150-4163338589ed
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f
golang.org/x/oauth2 v0.0.0-20210413134643-5e61552d6c78
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
golang.org/x/tools v0.1.0
gonum.org/v1/gonum v0.9.1

16
go.sum
View File

@ -78,10 +78,10 @@ github.com/Azure/azure-sdk-for-go v51.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9mo
github.com/Azure/azure-sdk-for-go v52.5.0+incompatible h1:/NLBWHCnIHtZyLPc1P7WIqi4Te4CC23kIQyK3Ep/7lA=
github.com/Azure/azure-sdk-for-go v52.5.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0/go.mod h1:pElNP+u99BvCZD+0jOlhI9OC/NB2IDTOTGZOZH0Qhq8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0 h1:ZsS7JltN+5D42mcU3Mb4lwVivlFL89v+FlXXMXE2YEM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0 h1:wb00szFWtKeIef2Q5X8gdd0mYp8oSHmJOYUh/QXD8sw=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1 h1:yQw8Ah26gBP4dv66ZNjZpRBRV+gaHH/0TLn1taU4FZ4=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1 h1:KchdKK3XlOjkzBROV+q3D+YgfRTvwoeBwbaoX4aVkjI=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE=
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM=
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 h1:vx8McI56N5oLSQu8xa+xdiE0fjQq8W8Zt49vHP8Rygw=
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM=
@ -2092,8 +2092,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210324051636-2c4c8ecb7826/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I=
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f h1:Si4U+UcgJzya9kpiEUJKQvjr512OLli+gL4poHrz93U=
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -2240,8 +2240,8 @@ golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210412220455-f1c623a9e750/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 h1:lCnv+lfrU9FRPGf8NeRuWAAPjNnema5WtBinMgs1fD8=
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -5,6 +5,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
@ -15,6 +16,7 @@ type AccessToken struct {
type TokenCredential interface {
GetCacheKey() string
Init() error
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
}
@ -31,7 +33,10 @@ type tokenCacheImpl struct {
}
type credentialCacheEntry struct {
credential TokenCredential
cache sync.Map // of *scopesCacheEntry
credInit uint32
credMutex sync.Mutex
cache sync.Map // of *scopesCacheEntry
}
type scopesCacheEntry struct {
@ -44,31 +49,67 @@ type scopesCacheEntry struct {
}
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
}
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
var entry interface{}
var ok bool
credentialKey := credential.GetCacheKey()
scopesKey := getKeyForScopes(scopes)
key := credential.GetCacheKey()
if entry, ok = c.cache.Load(credentialKey); !ok {
entry, _ = c.cache.LoadOrStore(credentialKey, &credentialCacheEntry{
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
credential: credential,
})
}
credentialEntry := entry.(*credentialCacheEntry)
return entry.(*credentialCacheEntry)
}
if entry, ok = credentialEntry.cache.Load(scopesKey); !ok {
entry, _ = credentialEntry.cache.LoadOrStore(scopesKey, &scopesCacheEntry{
credential: credentialEntry.credential,
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
err := c.ensureInitialized()
if err != nil {
return "", err
}
return c.getEntryFor(scopes).getAccessToken(ctx)
}
func (c *credentialCacheEntry) ensureInitialized() error {
if atomic.LoadUint32(&c.credInit) == 0 {
c.credMutex.Lock()
defer c.credMutex.Unlock()
if c.credInit == 0 {
// Initialize credential
err := c.credential.Init()
if err != nil {
return err
}
atomic.StoreUint32(&c.credInit, 1)
}
}
return nil
}
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
var entry interface{}
var ok bool
key := getKeyForScopes(scopes)
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
credential: c.credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
})
}
scopesEntry := entry.(*scopesCacheEntry)
return scopesEntry.getAccessToken(ctx)
return entry.(*scopesCacheEntry)
}
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {

View File

@ -14,7 +14,9 @@ import (
type fakeCredential struct {
key string
initCalledTimes int
calledTimes int
initFunc func() error
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
@ -22,6 +24,19 @@ func (c *fakeCredential) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) Reset() {
c.initCalledTimes = 0
c.calledTimes = 0
}
func (c *fakeCredential) Init() error {
c.initCalledTimes = c.initCalledTimes + 1
if c.initFunc != nil {
return c.initFunc()
}
return nil
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
@ -103,12 +118,168 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
})
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when credential init returns error", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
return errors.New("unable to initialize")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
err := cacheEntry.ensureInitialized()
assert.Error(t, err)
})
t.Run("should call init again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
return errors.New("unable to initialize")
}
return nil
},
}
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, credential.initCalledTimes)
})
})
t.Run("when credential init panics", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call credential init again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
panic(errors.New("unable to initialize"))
}
return nil
},
}
t.Run("should call credential init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, credential.initCalledTimes)
})
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential returns error", func(t *testing.T) {
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
@ -130,7 +301,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.calledTimes = 0
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
@ -152,7 +323,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential returns error only once", func(t *testing.T) {
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
@ -191,7 +362,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential panics", func(t *testing.T) {
t.Run("when credential getAccessToken panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
@ -199,7 +370,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.calledTimes = 0
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
@ -232,7 +403,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential panics only once", func(t *testing.T) {
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {

View File

@ -3,11 +3,8 @@ package pluginproxy
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@ -108,9 +105,8 @@ func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string)
}
type managedIdentityCredential struct {
clientId string
credLock sync.Mutex
credValue atomic.Value // of azcore.TokenCredential
clientId string
credential azcore.TokenCredential
}
func (c *managedIdentityCredential) GetCacheKey() string {
@ -121,39 +117,17 @@ func (c *managedIdentityCredential) GetCacheKey() string {
return fmt.Sprintf("azure|msi|%s", clientId)
}
func (c *managedIdentityCredential) getCredential() (azcore.TokenCredential, error) {
credential := c.credValue.Load()
if credential == nil {
c.credLock.Lock()
defer c.credLock.Unlock()
var err error
credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
if err != nil {
return nil, err
}
c.credValue.Store(credential)
func (c *managedIdentityCredential) Init() error {
if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil {
return err
} else {
c.credential = credential
return nil
}
return credential.(azcore.TokenCredential), nil
}
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
credential, err := c.getCredential()
if err != nil {
return nil, err
}
// Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource
if len(scopes) == 0 {
return nil, errors.New("scopes not provided")
}
resource := strings.TrimSuffix(scopes[0], "/.default")
scopes = []string{resource}
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err
}
@ -166,40 +140,25 @@ type clientSecretCredential struct {
tenantId string
clientId string
clientSecret string
credLock sync.Mutex
credValue atomic.Value // of azcore.TokenCredential
credential azcore.TokenCredential
}
func (c *clientSecretCredential) GetCacheKey() string {
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
}
func (c *clientSecretCredential) getCredential() (azcore.TokenCredential, error) {
credential := c.credValue.Load()
if credential == nil {
c.credLock.Lock()
defer c.credLock.Unlock()
var err error
credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
if err != nil {
return nil, err
}
c.credValue.Store(credential)
func (c *clientSecretCredential) Init() error {
options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority}
if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil {
return err
} else {
c.credential = credential
return nil
}
return credential.(azcore.TokenCredential), nil
}
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
credential, err := c.getCredential()
if err != nil {
return nil, err
}
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
if err != nil {
return nil, err
}