mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Migrate to Grafana Azure SDK (#47232)
This commit is contained in:
parent
a55274a72d
commit
5675496f6b
13
go.mod
13
go.mod
@ -17,8 +17,6 @@ require (
|
||||
cloud.google.com/go/storage v1.18.2
|
||||
cuelang.org/go v0.4.0
|
||||
github.com/Azure/azure-sdk-for-go v59.3.0+incompatible
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.10.0
|
||||
github.com/Azure/go-autorest/autorest v0.11.22
|
||||
github.com/BurntSushi/toml v0.3.1
|
||||
github.com/Masterminds/semver v1.5.0
|
||||
@ -54,6 +52,7 @@ require (
|
||||
github.com/gosimple/slug v1.9.0
|
||||
github.com/grafana/cuetsy v0.0.0-20211119211437-8c25464cc9bf
|
||||
github.com/grafana/grafana-aws-sdk v0.10.1
|
||||
github.com/grafana/grafana-azure-sdk-go v1.0.0
|
||||
github.com/grafana/grafana-plugin-sdk-go v0.129.0
|
||||
github.com/grafana/loki v1.6.2-0.20211015002020-7832783b1caa
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
@ -183,7 +182,7 @@ require (
|
||||
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/flatbuffers v2.0.0+incompatible // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.1.1
|
||||
github.com/googleapis/gax-go/v2 v2.1.1 // indirect
|
||||
github.com/gorilla/mux v1.8.0 // indirect
|
||||
github.com/grafana/grafana-google-sdk-go v0.0.0-20211104130251-b190293eaf58
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect
|
||||
@ -247,19 +246,20 @@ require (
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1
|
||||
google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1 // indirect
|
||||
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/kms v1.1.0
|
||||
github.com/Azure/go-autorest/autorest/adal v0.9.17
|
||||
github.com/golang-migrate/migrate/v4 v4.7.0
|
||||
github.com/grafana/dskit v0.0.0-20211011144203-3a88ec0b675f
|
||||
gocloud.dev v0.24.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.10.0 // indirect
|
||||
github.com/Azure/go-autorest/autorest/adal v0.9.17 // indirect
|
||||
github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect
|
||||
github.com/chromedp/cdproto v0.0.0-20220208224320-6efb837e6bc2 // indirect
|
||||
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 // indirect
|
||||
@ -270,6 +270,7 @@ require (
|
||||
github.com/envoyproxy/protoc-gen-validate v0.6.2 // indirect
|
||||
github.com/getkin/kin-openapi v0.91.0 // indirect
|
||||
github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 // indirect
|
||||
github.com/grafana/dskit v0.0.0-20211011144203-3a88ec0b675f // indirect
|
||||
github.com/imdario/mergo v0.3.12 // indirect
|
||||
github.com/klauspost/compress v1.13.6 // indirect
|
||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -1335,6 +1335,8 @@ github.com/grafana/go-mssqldb v0.0.0-20210326084033-d0ce3c521036 h1:GplhUk6Xes5J
|
||||
github.com/grafana/go-mssqldb v0.0.0-20210326084033-d0ce3c521036/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
|
||||
github.com/grafana/grafana-aws-sdk v0.10.1 h1:Ksguhjx6EuGLN/5Oc7oZoxuDReJ5RxIH99yqSMpLGUs=
|
||||
github.com/grafana/grafana-aws-sdk v0.10.1/go.mod h1:vFIOHEnY1u5nY0/tge1IHQjPuG6DRKr2ISf/HikUdjE=
|
||||
github.com/grafana/grafana-azure-sdk-go v1.0.0 h1:RIVQyVb89/y/BnOVsVDcxiMtmWF8NmAX8ql0OJvzwNc=
|
||||
github.com/grafana/grafana-azure-sdk-go v1.0.0/go.mod h1:xbzMaG74BN4rOP1NYEsCMNWkPbK7GfSU09PGYfQYm+g=
|
||||
github.com/grafana/grafana-google-sdk-go v0.0.0-20211104130251-b190293eaf58 h1:2ud7NNM7LrGPO4x0NFR8qLq68CqI4SmB7I2yRN2w9oE=
|
||||
github.com/grafana/grafana-google-sdk-go v0.0.0-20211104130251-b190293eaf58/go.mod h1:Vo2TKWfDVmNTELBUM+3lkrZvFtBws0qSZdXhQxRdJrE=
|
||||
github.com/grafana/grafana-plugin-sdk-go v0.94.0/go.mod h1:3VXz4nCv6wH5SfgB3mlW39s+c+LetqSCjFj7xxPC5+M=
|
||||
@ -2462,8 +2464,6 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn
|
||||
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||
github.com/vectordotdev/go-datemath v0.1.1-0.20220110192739-f9ce83ec349f h1:2upw/ZfjkCKpc4k6DXg7lMfCSLkfw/8epV5/y2ZUQ8U=
|
||||
github.com/vectordotdev/go-datemath v0.1.1-0.20220110192739-f9ce83ec349f/go.mod h1:PnwzbSst7KD3vpBzzlntZU5gjVa455Uqa5QPiKSYJzQ=
|
||||
github.com/vectordotdev/go-datemath v0.1.1-0.20220323213446-f3954d0b18ae h1:oyiy3uBj1F4O3AaFh7hUGBrJjAssJhKyAbwxtkslxqo=
|
||||
github.com/vectordotdev/go-datemath v0.1.1-0.20220323213446-f3954d0b18ae/go.mod h1:PnwzbSst7KD3vpBzzlntZU5gjVa455Uqa5QPiKSYJzQ=
|
||||
github.com/vektah/gqlparser v1.1.2/go.mod h1:1ycwN7Ij5njmMkPPAOaRFY4rET2Enx7IkVv3vaXspKw=
|
||||
|
@ -4,11 +4,12 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-azure-sdk-go/aztokenprovider"
|
||||
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
type azureAccessTokenProvider struct {
|
||||
|
@ -1,8 +1,9 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
)
|
||||
|
||||
type Cfg struct {
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
|
@ -12,6 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azhttpclient"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -24,8 +26,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/secrets"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azhttpclient"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
|
@ -8,7 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
@ -22,7 +24,6 @@ import (
|
||||
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -20,8 +20,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-aws-sdk/pkg/awsds"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
|
@ -1,6 +1,6 @@
|
||||
package setting
|
||||
|
||||
import "github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
import "github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
func (cfg *Cfg) readAzureSettings() {
|
||||
azureSettings := &azsettings.AzureSettings{}
|
||||
|
@ -3,7 +3,8 @@ package setting
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -1,83 +0,0 @@
|
||||
package azcredentials
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func FromDatasourceData(data map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
|
||||
if credentialsObj, err := getMapOptional(data, "azureCredentials"); err != nil {
|
||||
return nil, err
|
||||
} else if credentialsObj == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return getFromCredentialsObject(credentialsObj, secureData)
|
||||
}
|
||||
}
|
||||
|
||||
func getFromCredentialsObject(credentialsObj map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
|
||||
authType, err := getStringValue(credentialsObj, "authType")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case AzureAuthManagedIdentity:
|
||||
credentials := &AzureManagedIdentityCredentials{}
|
||||
return credentials, nil
|
||||
|
||||
case AzureAuthClientSecret:
|
||||
cloud, err := getStringValue(credentialsObj, "azureCloud")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tenantId, err := getStringValue(credentialsObj, "tenantId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientId, err := getStringValue(credentialsObj, "clientId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSecret := secureData["azureClientSecret"]
|
||||
|
||||
credentials := &AzureClientSecretCredentials{
|
||||
AzureCloud: cloud,
|
||||
TenantId: tenantId,
|
||||
ClientId: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
}
|
||||
return credentials, nil
|
||||
|
||||
default:
|
||||
err := fmt.Errorf("the authentication type '%s' not supported", authType)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func getMapOptional(obj map[string]interface{}, key string) (map[string]interface{}, error) {
|
||||
if untypedValue, ok := obj[key]; ok {
|
||||
if value, ok := untypedValue.(map[string]interface{}); ok {
|
||||
return value, nil
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be an object", key)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Value optional, not error
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getStringValue(obj map[string]interface{}, key string) (string, error) {
|
||||
if untypedValue, ok := obj[key]; ok {
|
||||
if value, ok := untypedValue.(string); ok {
|
||||
return value, nil
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be a string", key)
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be set", key)
|
||||
return "", err
|
||||
}
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
package azcredentials
|
||||
|
||||
const (
|
||||
AzureAuthManagedIdentity = "msi"
|
||||
AzureAuthClientSecret = "clientsecret"
|
||||
)
|
||||
|
||||
type AzureCredentials interface {
|
||||
AzureAuthType() string
|
||||
}
|
||||
|
||||
type AzureManagedIdentityCredentials struct {
|
||||
ClientId string
|
||||
}
|
||||
|
||||
type AzureClientSecretCredentials struct {
|
||||
AzureCloud string
|
||||
Authority string
|
||||
TenantId string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
func (credentials *AzureManagedIdentityCredentials) AzureAuthType() string {
|
||||
return AzureAuthManagedIdentity
|
||||
}
|
||||
|
||||
func (credentials *AzureClientSecretCredentials) AzureAuthType() string {
|
||||
return AzureAuthClientSecret
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package azhttpclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
const azureMiddlewareName = "AzureAuthentication"
|
||||
|
||||
func AzureMiddleware(settings *azsettings.AzureSettings, credentials azcredentials.AzureCredentials, scopes []string) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(azureMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(settings, credentials)
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
return ApplyAzureAuth(tokenProvider, scopes, next)
|
||||
})
|
||||
}
|
||||
|
||||
func ApplyAzureAuth(tokenProvider aztokenprovider.AzureTokenProvider, scopes []string, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
|
||||
func errorResponse(err error) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("invalid Azure configuration: %s", err)
|
||||
})
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
package azhttpclient
|
||||
|
||||
import (
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
)
|
||||
|
||||
func AddAzureAuthentication(opts *sdkhttpclient.Options, settings *azsettings.AzureSettings, credentials azcredentials.AzureCredentials, scopes []string) {
|
||||
opts.Middlewares = append(opts.Middlewares, AzureMiddleware(settings, credentials, scopes))
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package azsettings
|
||||
|
||||
import "strings"
|
||||
|
||||
const (
|
||||
AzurePublic = "AzureCloud"
|
||||
AzureChina = "AzureChinaCloud"
|
||||
AzureUSGovernment = "AzureUSGovernment"
|
||||
AzureGermany = "AzureGermanCloud"
|
||||
)
|
||||
|
||||
func NormalizeAzureCloud(cloudName string) string {
|
||||
switch strings.ToLower(cloudName) {
|
||||
// Public
|
||||
case "azurecloud":
|
||||
fallthrough
|
||||
case "azurepublic":
|
||||
fallthrough
|
||||
case "azurepubliccloud":
|
||||
fallthrough
|
||||
case "public":
|
||||
return AzurePublic
|
||||
|
||||
// China
|
||||
case "azurechina":
|
||||
fallthrough
|
||||
case "azurechinacloud":
|
||||
fallthrough
|
||||
case "china":
|
||||
return AzureChina
|
||||
|
||||
// US Government
|
||||
case "azureusgovernment":
|
||||
fallthrough
|
||||
case "azureusgovernmentcloud":
|
||||
fallthrough
|
||||
case "usgov":
|
||||
fallthrough
|
||||
case "usgovernment":
|
||||
return AzureUSGovernment
|
||||
|
||||
// Germany
|
||||
case "azuregermancloud":
|
||||
fallthrough
|
||||
case "azuregermany":
|
||||
fallthrough
|
||||
case "german":
|
||||
fallthrough
|
||||
case "germany":
|
||||
return AzureGermany
|
||||
}
|
||||
|
||||
// Pass the name unchanged if it's not known
|
||||
return cloudName
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
package azsettings
|
||||
|
||||
type AzureSettings struct {
|
||||
Cloud string
|
||||
ManagedIdentityEnabled bool
|
||||
ManagedIdentityClientId string
|
||||
}
|
@ -1,189 +0,0 @@
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenRetriever interface {
|
||||
GetCacheKey() string
|
||||
Init() error
|
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
type ConcurrentTokenCache interface {
|
||||
GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache {
|
||||
return &tokenCacheImpl{}
|
||||
}
|
||||
|
||||
type tokenCacheImpl struct {
|
||||
cache sync.Map // of *credentialCacheEntry
|
||||
}
|
||||
type credentialCacheEntry struct {
|
||||
retriever TokenRetriever
|
||||
|
||||
credInit uint32
|
||||
credMutex sync.Mutex
|
||||
cache sync.Map // of *scopesCacheEntry
|
||||
}
|
||||
|
||||
type scopesCacheEntry struct {
|
||||
retriever TokenRetriever
|
||||
scopes []string
|
||||
|
||||
cond *sync.Cond
|
||||
refreshing bool
|
||||
accessToken *AccessToken
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, tokenRetriever TokenRetriever, scopes []string) (string, error) {
|
||||
return c.getEntryFor(tokenRetriever).getAccessToken(ctx, scopes)
|
||||
}
|
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenRetriever) *credentialCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := credential.GetCacheKey()
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
|
||||
retriever: credential,
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*credentialCacheEntry)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
err := c.ensureInitialized()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.getEntryFor(scopes).getAccessToken(ctx)
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) ensureInitialized() error {
|
||||
if atomic.LoadUint32(&c.credInit) == 0 {
|
||||
c.credMutex.Lock()
|
||||
defer c.credMutex.Unlock()
|
||||
|
||||
if c.credInit == 0 {
|
||||
// Initialize retriever
|
||||
err := c.retriever.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.credInit, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
|
||||
var entry interface{}
|
||||
var ok bool
|
||||
|
||||
key := getKeyForScopes(scopes)
|
||||
|
||||
if entry, ok = c.cache.Load(key); !ok {
|
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
|
||||
retriever: c.retriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
})
|
||||
}
|
||||
|
||||
return entry.(*scopesCacheEntry)
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
|
||||
var accessToken *AccessToken
|
||||
var err error
|
||||
shouldRefresh := false
|
||||
|
||||
c.cond.L.Lock()
|
||||
for {
|
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(timeNow().Add(2*time.Minute)) {
|
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken
|
||||
break
|
||||
}
|
||||
|
||||
if !c.refreshing {
|
||||
// Start refreshing the token
|
||||
c.refreshing = true
|
||||
shouldRefresh = true
|
||||
break
|
||||
}
|
||||
|
||||
// Wait for the token to be refreshed
|
||||
c.cond.Wait()
|
||||
}
|
||||
c.cond.L.Unlock()
|
||||
|
||||
if shouldRefresh {
|
||||
accessToken, err = c.refreshAccessToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken.Token, nil
|
||||
}
|
||||
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
|
||||
var accessToken *AccessToken
|
||||
|
||||
// Safeguarding from panic caused by retriever implementation
|
||||
defer func() {
|
||||
c.cond.L.Lock()
|
||||
|
||||
c.refreshing = false
|
||||
|
||||
if accessToken != nil {
|
||||
c.accessToken = accessToken
|
||||
}
|
||||
|
||||
c.cond.Broadcast()
|
||||
c.cond.L.Unlock()
|
||||
}()
|
||||
|
||||
token, err := c.retriever.GetAccessToken(ctx, c.scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken = token
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func getKeyForScopes(scopes []string) string {
|
||||
if len(scopes) > 1 {
|
||||
arr := make([]string, len(scopes))
|
||||
copy(arr, scopes)
|
||||
sort.Strings(arr)
|
||||
scopes = arr
|
||||
}
|
||||
|
||||
return strings.Join(scopes, " ")
|
||||
}
|
@ -1,457 +0,0 @@
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeRetriever struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
initFunc func() error
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
func (c *fakeRetriever) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeRetriever) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeRetriever) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
return c.getAccessTokenFunc(ctx, scopes)
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
}
|
||||
|
||||
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes1 := []string{"Scope1"}
|
||||
scopes2 := []string{"Scope2"}
|
||||
|
||||
t.Run("should request access token from retriever", func(t *testing.T) {
|
||||
cache := NewConcurrentTokenCache()
|
||||
tokenRetriever := &fakeRetriever{key: "retriever"}
|
||||
|
||||
token, err := cache.GetAccessToken(ctx, tokenRetriever, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "retriever-token-1", token)
|
||||
|
||||
assert.Equal(t, 1, tokenRetriever.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential := &fakeRetriever{key: "credential-1"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-2", token2)
|
||||
|
||||
assert.Equal(t, 2, credential.calledTimes)
|
||||
})
|
||||
|
||||
t.Run("should return cached token for same credentials", func(t *testing.T) {
|
||||
var token1, token2 string
|
||||
var err error
|
||||
|
||||
cache := NewConcurrentTokenCache()
|
||||
credential1 := &fakeRetriever{key: "credential-1"}
|
||||
credential2 := &fakeRetriever{key: "credential-2"}
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-1-token-1", token1)
|
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "credential-2-token-1", token2)
|
||||
|
||||
assert.Equal(t, 1, credential1.calledTimes)
|
||||
assert.Equal(t, 1, credential2.calledTimes)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when retriever init returns error", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
return errors.New("unable to initialize")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever init panics", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever init again each time", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
tokenRetriever := &fakeRetriever{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to initialize"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, tokenRetriever.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when retriever getAccessToken returns error", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "", accessToken)
|
||||
})
|
||||
|
||||
t.Run("should call retriever again each time and return error", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
retriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return invalidToken, errors.New("unable to get access token")
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
retriever: retriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
|
||||
assert.Equal(t, 2, retriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever getAccessToken panics", func(t *testing.T) {
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever again each time", func(t *testing.T) {
|
||||
tokenRetriever.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when retriever getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
tokenRetriever := &fakeRetriever{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to get access token"))
|
||||
}
|
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
|
||||
return fakeAccessToken, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call retriever again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
retriever: tokenRetriever,
|
||||
scopes: scopes,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
var accessToken string
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "retriever expected to panic")
|
||||
}()
|
||||
_, _ = cacheEntry.getAccessToken(ctx)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "retriever not expected to panic")
|
||||
}()
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "token-2", accessToken)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, tokenRetriever.calledTimes)
|
||||
})
|
||||
})
|
||||
}
|
@ -1,187 +0,0 @@
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
)
|
||||
|
||||
var (
|
||||
azureTokenCache = NewConcurrentTokenCache()
|
||||
)
|
||||
|
||||
type AzureTokenProvider interface {
|
||||
GetAccessToken(ctx context.Context, scopes []string) (string, error)
|
||||
}
|
||||
|
||||
type tokenProviderImpl struct {
|
||||
tokenRetriever TokenRetriever
|
||||
}
|
||||
|
||||
func NewAzureAccessTokenProvider(settings *azsettings.AzureSettings, credentials azcredentials.AzureCredentials) (AzureTokenProvider, error) {
|
||||
if settings == nil {
|
||||
err := fmt.Errorf("parameter 'settings' cannot be nil")
|
||||
return nil, err
|
||||
}
|
||||
if credentials == nil {
|
||||
err := fmt.Errorf("parameter 'credentials' cannot be nil")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenRetriever TokenRetriever
|
||||
|
||||
switch c := credentials.(type) {
|
||||
case *azcredentials.AzureManagedIdentityCredentials:
|
||||
if !settings.ManagedIdentityEnabled {
|
||||
err := fmt.Errorf("managed identity authentication is not enabled in Grafana config")
|
||||
return nil, err
|
||||
} else {
|
||||
tokenRetriever = getManagedIdentityTokenRetriever(settings, c)
|
||||
}
|
||||
case *azcredentials.AzureClientSecretCredentials:
|
||||
tokenRetriever = getClientSecretTokenRetriever(c)
|
||||
default:
|
||||
err := fmt.Errorf("credentials of type '%s' not supported by authentication provider", c.AzureAuthType())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenProvider := &tokenProviderImpl{
|
||||
tokenRetriever: tokenRetriever,
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
func (provider *tokenProviderImpl) GetAccessToken(ctx context.Context, scopes []string) (string, error) {
|
||||
if ctx == nil {
|
||||
err := fmt.Errorf("parameter 'ctx' cannot be nil")
|
||||
return "", err
|
||||
}
|
||||
if scopes == nil {
|
||||
err := fmt.Errorf("parameter 'scopes' cannot be nil")
|
||||
return "", err
|
||||
}
|
||||
|
||||
accessToken, err := azureTokenCache.GetAccessToken(ctx, provider.tokenRetriever, scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func getManagedIdentityTokenRetriever(settings *azsettings.AzureSettings, credentials *azcredentials.AzureManagedIdentityCredentials) TokenRetriever {
|
||||
var clientId string
|
||||
if credentials.ClientId != "" {
|
||||
clientId = credentials.ClientId
|
||||
} else {
|
||||
clientId = settings.ManagedIdentityClientId
|
||||
}
|
||||
return &managedIdentityTokenRetriever{
|
||||
clientId: clientId,
|
||||
}
|
||||
}
|
||||
|
||||
func getClientSecretTokenRetriever(credentials *azcredentials.AzureClientSecretCredentials) TokenRetriever {
|
||||
var authority string
|
||||
if credentials.Authority != "" {
|
||||
authority = credentials.Authority
|
||||
} else {
|
||||
authority = resolveAuthorityForCloud(credentials.AzureCloud)
|
||||
}
|
||||
return &clientSecretTokenRetriever{
|
||||
authority: authority,
|
||||
tenantId: credentials.TenantId,
|
||||
clientId: credentials.ClientId,
|
||||
clientSecret: credentials.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAuthorityForCloud(cloudName string) string {
|
||||
// Known Azure clouds
|
||||
switch cloudName {
|
||||
case azsettings.AzurePublic:
|
||||
return azidentity.AzurePublicCloud
|
||||
case azsettings.AzureChina:
|
||||
return azidentity.AzureChina
|
||||
case azsettings.AzureUSGovernment:
|
||||
return azidentity.AzureGovernment
|
||||
case azsettings.AzureGermany:
|
||||
return azidentity.AzureGermany
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type managedIdentityTokenRetriever struct {
|
||||
clientId string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *managedIdentityTokenRetriever) GetCacheKey() string {
|
||||
clientId := c.clientId
|
||||
if clientId == "" {
|
||||
clientId = "system"
|
||||
}
|
||||
return fmt.Sprintf("azure|msi|%s", clientId)
|
||||
}
|
||||
|
||||
func (c *managedIdentityTokenRetriever) Init() error {
|
||||
if credential, err := azidentity.NewManagedIdentityCredential(c.clientId, nil); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.credential = credential
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *managedIdentityTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
accessToken, err := c.credential.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
type clientSecretTokenRetriever struct {
|
||||
authority string
|
||||
tenantId string
|
||||
clientId string
|
||||
clientSecret string
|
||||
credential azcore.TokenCredential
|
||||
}
|
||||
|
||||
func (c *clientSecretTokenRetriever) GetCacheKey() string {
|
||||
return fmt.Sprintf("azure|clientsecret|%s|%s|%s|%s", c.authority, c.tenantId, c.clientId, hashSecret(c.clientSecret))
|
||||
}
|
||||
|
||||
func (c *clientSecretTokenRetriever) Init() error {
|
||||
options := &azidentity.ClientSecretCredentialOptions{AuthorityHost: c.authority}
|
||||
if credential, err := azidentity.NewClientSecretCredential(c.tenantId, c.clientId, c.clientSecret, options); err != nil {
|
||||
return err
|
||||
} else {
|
||||
c.credential = credential
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientSecretTokenRetriever) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
accessToken, err := c.credential.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccessToken{Token: accessToken.Token, ExpiresOn: accessToken.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
func hashSecret(secret string) string {
|
||||
hash := sha256.New()
|
||||
_, _ = hash.Write([]byte(secret))
|
||||
return fmt.Sprintf("%x", hash.Sum(nil))
|
||||
}
|
@ -1,128 +0,0 @@
|
||||
package aztokenprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var getAccessTokenFunc func(credential TokenRetriever, scopes []string)
|
||||
|
||||
type tokenCacheFake struct{}
|
||||
|
||||
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenRetriever, scopes []string) (string, error) {
|
||||
getAccessTokenFunc(credential, scopes)
|
||||
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
settings := &azsettings.AzureSettings{}
|
||||
|
||||
scopes := []string{
|
||||
"https://management.azure.com/.default",
|
||||
}
|
||||
|
||||
original := azureTokenCache
|
||||
azureTokenCache = &tokenCacheFake{}
|
||||
t.Cleanup(func() { azureTokenCache = original })
|
||||
|
||||
t.Run("when managed identities enabled", func(t *testing.T) {
|
||||
settings.ManagedIdentityEnabled = true
|
||||
|
||||
t.Run("should resolve managed identity retriever if auth type is managed identity", func(t *testing.T) {
|
||||
credentials := &azcredentials.AzureManagedIdentityCredentials{}
|
||||
|
||||
provider, err := NewAzureAccessTokenProvider(settings, credentials)
|
||||
require.NoError(t, err)
|
||||
|
||||
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
|
||||
assert.IsType(t, &managedIdentityTokenRetriever{}, credential)
|
||||
}
|
||||
|
||||
_, err = provider.GetAccessToken(ctx, scopes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("should resolve client secret retriever if auth type is client secret", func(t *testing.T) {
|
||||
credentials := &azcredentials.AzureClientSecretCredentials{}
|
||||
|
||||
provider, err := NewAzureAccessTokenProvider(settings, credentials)
|
||||
require.NoError(t, err)
|
||||
|
||||
getAccessTokenFunc = func(credential TokenRetriever, scopes []string) {
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, credential)
|
||||
}
|
||||
|
||||
_, err = provider.GetAccessToken(ctx, scopes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when managed identities disabled", func(t *testing.T) {
|
||||
settings.ManagedIdentityEnabled = false
|
||||
|
||||
t.Run("should return error if auth type is managed identity", func(t *testing.T) {
|
||||
credentials := &azcredentials.AzureManagedIdentityCredentials{}
|
||||
|
||||
_, err := NewAzureAccessTokenProvider(settings, credentials)
|
||||
assert.Error(t, err, "managed identity authentication is not enabled in Grafana config")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
|
||||
credentials := &azcredentials.AzureClientSecretCredentials{
|
||||
AzureCloud: azsettings.AzurePublic,
|
||||
Authority: "",
|
||||
TenantId: "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4",
|
||||
ClientId: "1af7c188-e5b6-4f96-81b8-911761bdd459",
|
||||
ClientSecret: "0416d95e-8af8-472c-aaa3-15c93c46080a",
|
||||
}
|
||||
|
||||
t.Run("should return clientSecretTokenRetriever with values", func(t *testing.T) {
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://login.microsoftonline.com/", credential.authority)
|
||||
assert.Equal(t, "7dcf1d1a-4ec0-41f2-ac29-c1538a698bc4", credential.tenantId)
|
||||
assert.Equal(t, "1af7c188-e5b6-4f96-81b8-911761bdd459", credential.clientId)
|
||||
assert.Equal(t, "0416d95e-8af8-472c-aaa3-15c93c46080a", credential.clientSecret)
|
||||
})
|
||||
|
||||
t.Run("authority should selected based on cloud", func(t *testing.T) {
|
||||
originalCloud := credentials.AzureCloud
|
||||
defer func() { credentials.AzureCloud = originalCloud }()
|
||||
|
||||
credentials.AzureCloud = azsettings.AzureChina
|
||||
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://login.chinacloudapi.cn/", credential.authority)
|
||||
})
|
||||
|
||||
t.Run("explicitly set authority should have priority over cloud", func(t *testing.T) {
|
||||
originalCloud := credentials.AzureCloud
|
||||
defer func() { credentials.AzureCloud = originalCloud }()
|
||||
|
||||
credentials.AzureCloud = azsettings.AzureChina
|
||||
credentials.Authority = "https://another.com/"
|
||||
|
||||
result := getClientSecretTokenRetriever(credentials)
|
||||
assert.IsType(t, &clientSecretTokenRetriever{}, result)
|
||||
|
||||
credential := (result).(*clientSecretTokenRetriever)
|
||||
|
||||
assert.Equal(t, "https://another.com/", credential.authority)
|
||||
})
|
||||
}
|
@ -6,9 +6,11 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/metrics"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -6,15 +6,17 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/deprecated"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -3,10 +3,11 @@ package azuremonitor
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
)
|
||||
|
||||
// Azure cloud names specific to Azure Monitor
|
||||
|
@ -3,10 +3,11 @@ package azuremonitor
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -3,10 +3,11 @@ package azuremonitor
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azhttpclient"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azhttpclient"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/deprecated"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
)
|
||||
|
@ -5,10 +5,11 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/deprecated"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -11,13 +11,14 @@ import (
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azlog"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/loganalytics"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/macros"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -1,7 +1,8 @@
|
||||
package azuremonitor
|
||||
|
||||
import (
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azsettings"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/deprecated"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/types"
|
||||
)
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -5,10 +5,11 @@ import (
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azhttpclient"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azhttpclient"
|
||||
"github.com/grafana/grafana/pkg/util/maputil"
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user