From a337f70469e83556e452acc6a9ef83accf7453e4 Mon Sep 17 00:00:00 2001 From: Sergey Kostrukov Date: Mon, 24 May 2021 23:19:08 -0700 Subject: [PATCH] AzureMonitor: Fix Azure token provider national clouds (#34615) * Fix AAD authority for sovereign clouds * Update Azure SDK with scopes fix * Credential initialization in cache --- go.mod | 8 +- go.sum | 16 +- pkg/api/pluginproxy/token_cache.go | 65 +++++-- pkg/api/pluginproxy/token_cache_test.go | 183 +++++++++++++++++++- pkg/api/pluginproxy/token_provider_azure.go | 77 ++------ 5 files changed, 260 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 4f382baddf2..a227ca86274 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,8 @@ replace k8s.io/client-go => k8s.io/client-go v0.18.8 require ( cloud.google.com/go/storage v1.14.0 cuelang.org/go v0.3.2 - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1 github.com/BurntSushi/toml v0.3.1 github.com/Masterminds/semver v1.5.0 github.com/VividCortex/mysqlerr v0.0.0-20170204212430-6c6b55f8796f @@ -94,10 +94,10 @@ require ( go.opentelemetry.io/collector v0.25.0 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a golang.org/x/exp v0.0.0-20210220032938-85be41e4509f // indirect - golang.org/x/net v0.0.0-20210510120150-4163338589ed + golang.org/x/net v0.0.0-20210521195947-fe42d452be8f golang.org/x/oauth2 v0.0.0-20210413134643-5e61552d6c78 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect + golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba golang.org/x/tools v0.1.0 gonum.org/v1/gonum v0.9.1 diff --git a/go.sum b/go.sum index 7c57193f70b..ce078b50bf4 100644 --- a/go.sum +++ b/go.sum @@ -78,10 +78,10 @@ github.com/Azure/azure-sdk-for-go v51.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9mo github.com/Azure/azure-sdk-for-go v52.5.0+incompatible h1:/NLBWHCnIHtZyLPc1P7WIqi4Te4CC23kIQyK3Ep/7lA= github.com/Azure/azure-sdk-for-go v52.5.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0/go.mod h1:pElNP+u99BvCZD+0jOlhI9OC/NB2IDTOTGZOZH0Qhq8= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0 h1:ZsS7JltN+5D42mcU3Mb4lwVivlFL89v+FlXXMXE2YEM= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0 h1:wb00szFWtKeIef2Q5X8gdd0mYp8oSHmJOYUh/QXD8sw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1 h1:yQw8Ah26gBP4dv66ZNjZpRBRV+gaHH/0TLn1taU4FZ4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1 h1:KchdKK3XlOjkzBROV+q3D+YgfRTvwoeBwbaoX4aVkjI= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE= github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 h1:vx8McI56N5oLSQu8xa+xdiE0fjQq8W8Zt49vHP8Rygw= github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM= @@ -2092,8 +2092,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210324051636-2c4c8ecb7826/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= -golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I= -golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210521195947-fe42d452be8f h1:Si4U+UcgJzya9kpiEUJKQvjr512OLli+gL4poHrz93U= +golang.org/x/net v0.0.0-20210521195947-fe42d452be8f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -2240,8 +2240,8 @@ golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210412220455-f1c623a9e750/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E= -golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 h1:lCnv+lfrU9FRPGf8NeRuWAAPjNnema5WtBinMgs1fD8= +golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/api/pluginproxy/token_cache.go b/pkg/api/pluginproxy/token_cache.go index 91a3d92d361..a8f35867a75 100644 --- a/pkg/api/pluginproxy/token_cache.go +++ b/pkg/api/pluginproxy/token_cache.go @@ -5,6 +5,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" ) @@ -15,6 +16,7 @@ type AccessToken struct { type TokenCredential interface { GetCacheKey() string + Init() error GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) } @@ -31,7 +33,10 @@ type tokenCacheImpl struct { } type credentialCacheEntry struct { credential TokenCredential - cache sync.Map // of *scopesCacheEntry + + credInit uint32 + credMutex sync.Mutex + cache sync.Map // of *scopesCacheEntry } type scopesCacheEntry struct { @@ -44,31 +49,67 @@ type scopesCacheEntry struct { } func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) { + return c.getEntryFor(credential).getAccessToken(ctx, scopes) +} + +func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry { var entry interface{} var ok bool - credentialKey := credential.GetCacheKey() - scopesKey := getKeyForScopes(scopes) + key := credential.GetCacheKey() - if entry, ok = c.cache.Load(credentialKey); !ok { - entry, _ = c.cache.LoadOrStore(credentialKey, &credentialCacheEntry{ + if entry, ok = c.cache.Load(key); !ok { + entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{ credential: credential, }) } - credentialEntry := entry.(*credentialCacheEntry) + return entry.(*credentialCacheEntry) +} - if entry, ok = credentialEntry.cache.Load(scopesKey); !ok { - entry, _ = credentialEntry.cache.LoadOrStore(scopesKey, &scopesCacheEntry{ - credential: credentialEntry.credential, +func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) { + err := c.ensureInitialized() + if err != nil { + return "", err + } + + return c.getEntryFor(scopes).getAccessToken(ctx) +} + +func (c *credentialCacheEntry) ensureInitialized() error { + if atomic.LoadUint32(&c.credInit) == 0 { + c.credMutex.Lock() + defer c.credMutex.Unlock() + + if c.credInit == 0 { + // Initialize credential + err := c.credential.Init() + if err != nil { + return err + } + + atomic.StoreUint32(&c.credInit, 1) + } + } + + return nil +} + +func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry { + var entry interface{} + var ok bool + + key := getKeyForScopes(scopes) + + if entry, ok = c.cache.Load(key); !ok { + entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{ + credential: c.credential, scopes: scopes, cond: sync.NewCond(&sync.Mutex{}), }) } - scopesEntry := entry.(*scopesCacheEntry) - - return scopesEntry.getAccessToken(ctx) + return entry.(*scopesCacheEntry) } func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) { diff --git a/pkg/api/pluginproxy/token_cache_test.go b/pkg/api/pluginproxy/token_cache_test.go index 904e64d39b8..4edd67db9ef 100644 --- a/pkg/api/pluginproxy/token_cache_test.go +++ b/pkg/api/pluginproxy/token_cache_test.go @@ -14,7 +14,9 @@ import ( type fakeCredential struct { key string + initCalledTimes int calledTimes int + initFunc func() error getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error) } @@ -22,6 +24,19 @@ func (c *fakeCredential) GetCacheKey() string { return c.key } +func (c *fakeCredential) Reset() { + c.initCalledTimes = 0 + c.calledTimes = 0 +} + +func (c *fakeCredential) Init() error { + c.initCalledTimes = c.initCalledTimes + 1 + if c.initFunc != nil { + return c.initFunc() + } + return nil +} + func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { c.calledTimes = c.calledTimes + 1 if c.getAccessTokenFunc != nil { @@ -103,12 +118,168 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { }) } +func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { + t.Run("when credential init returns error", func(t *testing.T) { + credential := &fakeCredential{ + initFunc: func() error { + return errors.New("unable to initialize") + }, + } + + t.Run("should return error", func(t *testing.T) { + cacheEntry := &credentialCacheEntry{ + credential: credential, + } + + err := cacheEntry.ensureInitialized() + + assert.Error(t, err) + }) + + t.Run("should call init again each time and return error", func(t *testing.T) { + credential.Reset() + + cacheEntry := &credentialCacheEntry{ + credential: credential, + } + + var err error + err = cacheEntry.ensureInitialized() + assert.Error(t, err) + + err = cacheEntry.ensureInitialized() + assert.Error(t, err) + + err = cacheEntry.ensureInitialized() + assert.Error(t, err) + + assert.Equal(t, 3, credential.initCalledTimes) + }) + }) + + t.Run("when credential init returns error only once", func(t *testing.T) { + var times = 0 + credential := &fakeCredential{ + initFunc: func() error { + times = times + 1 + if times == 1 { + return errors.New("unable to initialize") + } + return nil + }, + } + + t.Run("should call credential init again only while it returns error", func(t *testing.T) { + cacheEntry := &credentialCacheEntry{ + credential: credential, + } + + var err error + err = cacheEntry.ensureInitialized() + assert.Error(t, err) + + err = cacheEntry.ensureInitialized() + assert.NoError(t, err) + + err = cacheEntry.ensureInitialized() + assert.NoError(t, err) + + assert.Equal(t, 2, credential.initCalledTimes) + }) + }) + + t.Run("when credential init panics", func(t *testing.T) { + credential := &fakeCredential{ + initFunc: func() error { + panic(errors.New("unable to initialize")) + }, + } + + t.Run("should call credential init again each time", func(t *testing.T) { + credential.Reset() + + cacheEntry := &credentialCacheEntry{ + credential: credential, + } + + func() { + defer func() { + assert.NotNil(t, recover(), "credential expected to panic") + }() + _ = cacheEntry.ensureInitialized() + }() + + func() { + defer func() { + assert.NotNil(t, recover(), "credential expected to panic") + }() + _ = cacheEntry.ensureInitialized() + }() + + func() { + defer func() { + assert.NotNil(t, recover(), "credential expected to panic") + }() + _ = cacheEntry.ensureInitialized() + }() + + assert.Equal(t, 3, credential.initCalledTimes) + }) + }) + + t.Run("when credential init panics only once", func(t *testing.T) { + var times = 0 + credential := &fakeCredential{ + initFunc: func() error { + times = times + 1 + if times == 1 { + panic(errors.New("unable to initialize")) + } + return nil + }, + } + + t.Run("should call credential init again only while it panics", func(t *testing.T) { + cacheEntry := &credentialCacheEntry{ + credential: credential, + } + + var err error + + func() { + defer func() { + assert.NotNil(t, recover(), "credential expected to panic") + }() + _ = cacheEntry.ensureInitialized() + }() + + func() { + defer func() { + assert.Nil(t, recover(), "credential not expected to panic") + }() + err = cacheEntry.ensureInitialized() + assert.NoError(t, err) + }() + + func() { + defer func() { + assert.Nil(t, recover(), "credential not expected to panic") + }() + err = cacheEntry.ensureInitialized() + assert.NoError(t, err) + }() + + assert.Equal(t, 2, credential.initCalledTimes) + }) + }) +} + func TestScopesCacheEntry_GetAccessToken(t *testing.T) { ctx := context.Background() scopes := []string{"Scope1"} - t.Run("when credential returns error", func(t *testing.T) { + t.Run("when credential getAccessToken returns error", func(t *testing.T) { credential := &fakeCredential{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)} @@ -130,7 +301,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }) t.Run("should call credential again each time and return error", func(t *testing.T) { - credential.calledTimes = 0 + credential.Reset() cacheEntry := &scopesCacheEntry{ credential: credential, @@ -152,7 +323,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }) }) - t.Run("when credential returns error only once", func(t *testing.T) { + t.Run("when credential getAccessToken returns error only once", func(t *testing.T) { var times = 0 credential := &fakeCredential{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { @@ -191,7 +362,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }) }) - t.Run("when credential panics", func(t *testing.T) { + t.Run("when credential getAccessToken panics", func(t *testing.T) { credential := &fakeCredential{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { panic(errors.New("unable to get access token")) @@ -199,7 +370,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { } t.Run("should call credential again each time", func(t *testing.T) { - credential.calledTimes = 0 + credential.Reset() cacheEntry := &scopesCacheEntry{ credential: credential, @@ -232,7 +403,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) { }) }) - t.Run("when credential panics only once", func(t *testing.T) { + t.Run("when credential getAccessToken panics only once", func(t *testing.T) { var times = 0 credential := &fakeCredential{ getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { diff --git a/pkg/api/pluginproxy/token_provider_azure.go b/pkg/api/pluginproxy/token_provider_azure.go index a72d9355b52..68fcae750f6 100644 --- a/pkg/api/pluginproxy/token_provider_azure.go +++ b/pkg/api/pluginproxy/token_provider_azure.go @@ -3,11 +3,8 @@ package pluginproxy import ( "context" "crypto/sha256" - "errors" "fmt" "strings" - "sync" - "sync/atomic" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -108,9 +105,8 @@ func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string) } type managedIdentityCredential struct { - clientId string - credLock sync.Mutex - credValue atomic.Value // of azcore.TokenCredential + clientId string + credential azcore.TokenCredential } func (c *managedIdentityCredential) GetCacheKey() string { @@ -121,39 +117,17 @@ func (c *managedIdentityCredential) GetCacheKey() string { return fmt.Sprintf("azure|msi|%s", clientId) } -func (c *managedIdentityCredential) getCredential() (azcore.TokenCredential, error) { - credential := c.credValue.Load() - - if credential == nil { - c.credLock.Lock() - defer c.credLock.Unlock() - - var err error - credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil) - if err != nil { - return nil, err - } - - c.credValue.Store(credential) +func (c *managedIdentityCredential) Init() error { + if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil { + return err + } else { + c.credential = credential + return nil } - - return credential.(azcore.TokenCredential), nil } func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { - credential, err := c.getCredential() - if err != nil { - return nil, err - } - - // Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource - if len(scopes) == 0 { - return nil, errors.New("scopes not provided") - } - resource := strings.TrimSuffix(scopes[0], "/.default") - scopes = []string{resource} - - accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) + accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) if err != nil { return nil, err } @@ -166,40 +140,25 @@ type clientSecretCredential struct { tenantId string clientId string clientSecret string - credLock sync.Mutex - credValue atomic.Value // of azcore.TokenCredential + credential azcore.TokenCredential } func (c *clientSecretCredential) GetCacheKey() string { return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret)) } -func (c *clientSecretCredential) getCredential() (azcore.TokenCredential, error) { - credential := c.credValue.Load() - - if credential == nil { - c.credLock.Lock() - defer c.credLock.Unlock() - - var err error - credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil) - if err != nil { - return nil, err - } - - c.credValue.Store(credential) +func (c *clientSecretCredential) Init() error { + options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority} + if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil { + return err + } else { + c.credential = credential + return nil } - - return credential.(azcore.TokenCredential), nil } func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { - credential, err := c.getCredential() - if err != nil { - return nil, err - } - - accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) + accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes}) if err != nil { return nil, err }