mirror of
https://github.com/grafana/grafana.git
synced 2025-02-11 16:15:42 -06:00
AzureMonitor: Fix Azure token provider national clouds (#34615)
* Fix AAD authority for sovereign clouds * Update Azure SDK with scopes fix * Credential initialization in cache
This commit is contained in:
parent
9b518669dd
commit
a337f70469
8
go.mod
8
go.mod
@ -14,8 +14,8 @@ replace k8s.io/client-go => k8s.io/client-go v0.18.8
|
||||
require (
|
||||
cloud.google.com/go/storage v1.14.0
|
||||
cuelang.org/go v0.3.2
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1
|
||||
github.com/BurntSushi/toml v0.3.1
|
||||
github.com/Masterminds/semver v1.5.0
|
||||
github.com/VividCortex/mysqlerr v0.0.0-20170204212430-6c6b55f8796f
|
||||
@ -94,10 +94,10 @@ require (
|
||||
go.opentelemetry.io/collector v0.25.0
|
||||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
|
||||
golang.org/x/exp v0.0.0-20210220032938-85be41e4509f // indirect
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed
|
||||
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f
|
||||
golang.org/x/oauth2 v0.0.0-20210413134643-5e61552d6c78
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect
|
||||
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 // indirect
|
||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
|
||||
golang.org/x/tools v0.1.0
|
||||
gonum.org/v1/gonum v0.9.1
|
||||
|
16
go.sum
16
go.sum
@ -78,10 +78,10 @@ github.com/Azure/azure-sdk-for-go v51.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9mo
|
||||
github.com/Azure/azure-sdk-for-go v52.5.0+incompatible h1:/NLBWHCnIHtZyLPc1P7WIqi4Te4CC23kIQyK3Ep/7lA=
|
||||
github.com/Azure/azure-sdk-for-go v52.5.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.14.0/go.mod h1:pElNP+u99BvCZD+0jOlhI9OC/NB2IDTOTGZOZH0Qhq8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0 h1:ZsS7JltN+5D42mcU3Mb4lwVivlFL89v+FlXXMXE2YEM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.0/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0 h1:wb00szFWtKeIef2Q5X8gdd0mYp8oSHmJOYUh/QXD8sw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.8.0/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1 h1:yQw8Ah26gBP4dv66ZNjZpRBRV+gaHH/0TLn1taU4FZ4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.16.1/go.mod h1:MVdrcUC4Hup35qHym3VdzoW+NBgBxrta9Vei97jRtM8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1 h1:KchdKK3XlOjkzBROV+q3D+YgfRTvwoeBwbaoX4aVkjI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.9.1/go.mod h1:acANgl9stsT5xflESXKjZx4rhZJSr0TGgTDYY0xJPIE=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.0/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1 h1:vx8McI56N5oLSQu8xa+xdiE0fjQq8W8Zt49vHP8Rygw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v0.5.1/go.mod h1:k4KbFSunV/+0hOHL1vyFaPsiYQ1Vmvy1TBpmtvCDLZM=
|
||||
@ -2092,8 +2092,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
|
||||
golang.org/x/net v0.0.0-20210324051636-2c4c8ecb7826/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
|
||||
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I=
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f h1:Si4U+UcgJzya9kpiEUJKQvjr512OLli+gL4poHrz93U=
|
||||
golang.org/x/net v0.0.0-20210521195947-fe42d452be8f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@ -2240,8 +2240,8 @@ golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210412220455-f1c623a9e750/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E=
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1 h1:lCnv+lfrU9FRPGf8NeRuWAAPjNnema5WtBinMgs1fD8=
|
||||
golang.org/x/sys v0.0.0-20210521203332-0cec03c779c1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -15,6 +16,7 @@ type AccessToken struct {
|
||||
|
||||
type TokenCredential interface {
|
||||
GetCacheKey() string
|
||||
Init() error
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
@ -31,7 +33,10 @@ type tokenCacheImpl struct {
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
credential TokenCredential
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
|
||||
credInit uint32
|
||||
credMutex sync.Mutex
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
@ -44,31 +49,67 @@ type scopesCacheEntry struct {
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
|
||||
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
credentialKey := credential.GetCacheKey()
|
||||
scopesKey := getKeyForScopes(scopes)
|
||||
key := credential.GetCacheKey()
|
||||
|
||||
if entry, ok = c.cache.Load(credentialKey); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(credentialKey, &credentialCacheEntry{
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
|
||||
credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
credentialEntry := entry.(*credentialCacheEntry)
|
||||
return entry.(*credentialCacheEntry)
|
||||
}
|
||||
|
||||
if entry, ok = credentialEntry.cache.Load(scopesKey); !ok {
|
||||
entry, _ = credentialEntry.cache.LoadOrStore(scopesKey, &scopesCacheEntry{
|
||||
credential: credentialEntry.credential,
|
||||
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
err := c.ensureInitialized()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.getEntryFor(scopes).getAccessToken(ctx)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) ensureInitialized() error {
|
||||
if atomic.LoadUint32(&c.credInit) == 0 {
|
||||
c.credMutex.Lock()
|
||||
defer c.credMutex.Unlock()
|
||||
|
||||
if c.credInit == 0 {
|
||||
// Initialize credential
|
||||
err := c.credential.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.credInit, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := getKeyForScopes(scopes)
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
|
||||
credential: c.credential,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
scopesEntry := entry.(*scopesCacheEntry)
|
||||
|
||||
return scopesEntry.getAccessToken(ctx)
|
||||
return entry.(*scopesCacheEntry)
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
|
@ -14,7 +14,9 @@ import (
|
||||
|
||||
type fakeCredential struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
initFunc func() error
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
@ -22,6 +24,19 @@ func (c *fakeCredential) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
@ -103,12 +118,168 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when credential init returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
return errors.New("unable to initialize")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to initialize"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential returns error", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
@ -130,7 +301,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
@ -152,7 +323,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential returns error only once", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
@ -191,7 +362,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
@ -199,7 +370,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
@ -232,7 +403,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics only once", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
|
@ -3,11 +3,8 @@ package pluginproxy
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
@ -108,9 +105,8 @@ func (provider *azureAccessTokenProvider) resolveAuthorityHost(cloudName string)
|
||||
}
|
||||
|
||||
type managedIdentityCredential struct {
|
||||
clientId string
|
||||
credLock sync.Mutex
|
||||
credValue atomic.Value // of azcore.TokenCredential
|
||||
clientId string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
@ -121,39 +117,17 @@ func (c *managedIdentityCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|msi|%s", clientId)
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) getCredential() (azcore.TokenCredential, error) {
|
||||
credential := c.credValue.Load()
|
||||
|
||||
if credential == nil {
|
||||
c.credLock.Lock()
|
||||
defer c.credLock.Unlock()
|
||||
|
||||
var err error
|
||||
credential, err = azidentity.NewManagedIdentityCredential(c.clientId, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.credValue.Store(credential)
|
||||
func (c *managedIdentityCredential) Init() error {
|
||||
if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.credential = credential
|
||||
return nil
|
||||
}
|
||||
|
||||
return credential.(azcore.TokenCredential), nil
|
||||
}
|
||||
|
||||
func (c *managedIdentityCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
credential, err := c.getCredential()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Implementation of ManagedIdentityCredential doesn't support scopes, converting to resource
|
||||
if len(scopes) == 0 {
|
||||
return nil, errors.New("scopes not provided")
|
||||
}
|
||||
resource := strings.TrimSuffix(scopes[0], "/.default")
|
||||
scopes = []string{resource}
|
||||
|
||||
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -166,40 +140,25 @@ type clientSecretCredential struct {
|
||||
tenantId string
|
||||
clientId string
|
||||
clientSecret string
|
||||
credLock sync.Mutex
|
||||
credValue atomic.Value // of azcore.TokenCredential
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) getCredential() (azcore.TokenCredential, error) {
|
||||
credential := c.credValue.Load()
|
||||
|
||||
if credential == nil {
|
||||
c.credLock.Lock()
|
||||
defer c.credLock.Unlock()
|
||||
|
||||
var err error
|
||||
credential, err = azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.credValue.Store(credential)
|
||||
func (c *clientSecretCredential) Init() error {
|
||||
options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority}
|
||||
if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.credential = credential
|
||||
return nil
|
||||
}
|
||||
|
||||
return credential.(azcore.TokenCredential), nil
|
||||
}
|
||||
|
||||
func (c *clientSecretCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
credential, err := c.getCredential()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
accessToken, err := c.credential.GetToken(ctx, azcore.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user