mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Infra: Azure authentication in HttpClientProvider (#36932)
* Azure middleware in HttpClientProxy * Azure authentication under feature flag * Minor fixes * Add prefixes to not clash with JsonData * Return error if JsonData cannot be parsed * Return original string if URL invalid * Tests for datasource_cache
This commit is contained in:
113
pkg/infra/httpclient/httpclientprovider/azure_middleware.go
Normal file
113
pkg/infra/httpclient/httpclientprovider/azure_middleware.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package httpclientprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
|
||||
)
|
||||
|
||||
const azureMiddlewareName = "AzureAuthentication.Provider"
|
||||
|
||||
func AzureMiddleware(cfg *setting.Cfg) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(azureMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
if enabled, err := isAzureAuthenticationEnabled(opts.CustomOptions); err != nil {
|
||||
return errorResponse(err)
|
||||
} else if !enabled {
|
||||
return next
|
||||
}
|
||||
|
||||
credentials, err := getAzureCredentials(opts.CustomOptions)
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
} else if credentials == nil {
|
||||
credentials = getDefaultAzureCredentials(cfg)
|
||||
}
|
||||
|
||||
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, credentials)
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
scopes, err := getAzureEndpointScopes(opts.CustomOptions)
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
return aztokenprovider.ApplyAuth(tokenProvider, scopes, next)
|
||||
})
|
||||
}
|
||||
|
||||
func errorResponse(err error) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("invalid Azure configuration: %s", err)
|
||||
})
|
||||
}
|
||||
|
||||
func isAzureAuthenticationEnabled(customOptions map[string]interface{}) (bool, error) {
|
||||
if untypedValue, ok := customOptions["_azureAuth"]; !ok {
|
||||
return false, nil
|
||||
} else if value, ok := untypedValue.(bool); !ok {
|
||||
err := fmt.Errorf("the field 'azureAuth' should be a bool")
|
||||
return false, err
|
||||
} else {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureCredentials(customOptions map[string]interface{}) (azcredentials.AzureCredentials, error) {
|
||||
if untypedValue, ok := customOptions["_azureCredentials"]; !ok {
|
||||
return nil, nil
|
||||
} else if value, ok := untypedValue.(azcredentials.AzureCredentials); !ok {
|
||||
err := fmt.Errorf("the field 'azureCredentials' should be a valid credentials object")
|
||||
return nil, err
|
||||
} else {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultAzureCredentials(cfg *setting.Cfg) azcredentials.AzureCredentials {
|
||||
if cfg.Azure.ManagedIdentityEnabled {
|
||||
return &azcredentials.AzureManagedIdentityCredentials{}
|
||||
} else {
|
||||
return &azcredentials.AzureClientSecretCredentials{
|
||||
AzureCloud: cfg.Azure.Cloud,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureEndpointResourceId(customOptions map[string]interface{}) (*url.URL, error) {
|
||||
var value string
|
||||
if untypedValue, ok := customOptions["azureEndpointResourceId"]; !ok {
|
||||
err := fmt.Errorf("the field 'azureEndpointResourceId' should be set")
|
||||
return nil, err
|
||||
} else if value, ok = untypedValue.(string); !ok {
|
||||
err := fmt.Errorf("the field 'azureEndpointResourceId' should be a string")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceId, err := url.Parse(value)
|
||||
if err != nil || resourceId.Scheme == "" || resourceId.Host == "" {
|
||||
err := fmt.Errorf("invalid endpoint Resource ID URL '%s'", value)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resourceId, nil
|
||||
}
|
||||
|
||||
func getAzureEndpointScopes(customOptions map[string]interface{}) ([]string, error) {
|
||||
resourceId, err := getAzureEndpointResourceId(customOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceId.Path = path.Join(resourceId.Path, ".default")
|
||||
scopes := []string{resourceId.String()}
|
||||
|
||||
return scopes, nil
|
||||
}
|
@@ -34,6 +34,10 @@ func New(cfg *setting.Cfg) httpclient.Provider {
|
||||
|
||||
setDefaultTimeoutOptions(cfg)
|
||||
|
||||
if cfg.FeatureToggles["httpclientprovider_azure_auth"] {
|
||||
middlewares = append(middlewares, AzureMiddleware(cfg))
|
||||
}
|
||||
|
||||
return newProviderFunc(sdkhttpclient.ProviderOptions{
|
||||
Middlewares: middlewares,
|
||||
ConfigureTransport: func(opts sdkhttpclient.Options, transport *http.Transport) {
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
)
|
||||
|
||||
func (ds *DataSource) getTimeout() time.Duration {
|
||||
@@ -66,10 +67,14 @@ func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider, customMiddl
|
||||
return t.roundTripper, nil
|
||||
}
|
||||
|
||||
opts := ds.HTTPClientOptions()
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts.Middlewares = customMiddlewares
|
||||
|
||||
rt, err := provider.GetTransport(opts)
|
||||
rt, err := provider.GetTransport(*opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +87,7 @@ func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider, customMiddl
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
|
||||
func (ds *DataSource) HTTPClientOptions() (*sdkhttpclient.Options, error) {
|
||||
tlsOptions := ds.TLSOptions()
|
||||
timeouts := &sdkhttpclient.TimeoutOptions{
|
||||
Timeout: ds.getTimeout(),
|
||||
@@ -95,7 +100,7 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
|
||||
MaxIdleConnsPerHost: sdkhttpclient.DefaultTimeoutOptions.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: sdkhttpclient.DefaultTimeoutOptions.IdleConnTimeout,
|
||||
}
|
||||
opts := sdkhttpclient.Options{
|
||||
opts := &sdkhttpclient.Options{
|
||||
Timeouts: timeouts,
|
||||
Headers: getCustomHeaders(ds.JsonData, ds.DecryptedValues()),
|
||||
Labels: map[string]string{
|
||||
@@ -121,6 +126,19 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
|
||||
}
|
||||
}
|
||||
|
||||
if ds.JsonData != nil && ds.JsonData.Get("azureAuth").MustBool() {
|
||||
credentials, err := azcredentials.FromDatasourceData(ds.JsonData.MustMap(), ds.DecryptedValues())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid Azure credentials: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts.CustomOptions["_azureAuth"] = true
|
||||
if credentials != nil {
|
||||
opts.CustomOptions["_azureCredentials"] = credentials
|
||||
}
|
||||
}
|
||||
|
||||
if ds.JsonData != nil && ds.JsonData.Get("sigV4Auth").MustBool(false) {
|
||||
opts.SigV4 = &sdkhttpclient.SigV4Config{
|
||||
Service: awsServiceNamespace(ds.Type),
|
||||
@@ -140,7 +158,7 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (ds *DataSource) TLSOptions() sdkhttpclient.TLSOptions {
|
||||
@@ -180,7 +198,11 @@ func (ds *DataSource) TLSOptions() sdkhttpclient.TLSOptions {
|
||||
}
|
||||
|
||||
func (ds *DataSource) GetTLSConfig(httpClientProvider httpclient.Provider) (*tls.Config, error) {
|
||||
return httpClientProvider.GetTLSConfig(ds.HTTPClientOptions())
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return httpClientProvider.GetTLSConfig(*opts)
|
||||
}
|
||||
|
||||
// getCustomHeaders returns a map with all the to be set headers
|
||||
|
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -394,6 +395,109 @@ func TestDataSource_DecryptedValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDataSource_HTTPClientOptions(t *testing.T) {
|
||||
emptyJsonData := simplejson.New()
|
||||
emptySecureJsonData := map[string][]byte{}
|
||||
|
||||
ds := DataSource{
|
||||
Id: 1,
|
||||
Url: "https://api.example.com",
|
||||
Type: "prometheus",
|
||||
}
|
||||
|
||||
t.Run("Azure authentication", func(t *testing.T) {
|
||||
t.Run("should be disabled if not enabled in JsonData", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, true, opts.CustomOptions["_azureAuth"])
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
|
||||
})
|
||||
|
||||
t.Run("should be enabled if enabled in JsonData without credentials configured", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuth": true,
|
||||
})
|
||||
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, true, opts.CustomOptions["_azureAuth"])
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
|
||||
})
|
||||
|
||||
t.Run("should be enabled if enabled in JsonData with credentials configured", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuth": true,
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
})
|
||||
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, true, opts.CustomOptions["_azureAuth"])
|
||||
|
||||
require.Contains(t, opts.CustomOptions, "_azureCredentials")
|
||||
credentials := opts.CustomOptions["_azureCredentials"]
|
||||
|
||||
assert.IsType(t, &azcredentials.AzureManagedIdentityCredentials{}, credentials)
|
||||
})
|
||||
|
||||
t.Run("should be disabled if disabled in JsonData even with credentials configured", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuth": false,
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
})
|
||||
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, true, opts.CustomOptions["_azureAuth"])
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
|
||||
})
|
||||
|
||||
t.Run("should fail if credentials are invalid", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureAuth": true,
|
||||
"azureCredentials": "invalid",
|
||||
})
|
||||
|
||||
_, err := ds.HTTPClientOptions()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should pass resourceId from JsonData", func(t *testing.T) {
|
||||
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
|
||||
|
||||
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
|
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
|
||||
})
|
||||
|
||||
opts, err := ds.HTTPClientOptions()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, opts.CustomOptions, "azureEndpointResourceId")
|
||||
azureEndpointResourceId := opts.CustomOptions["azureEndpointResourceId"]
|
||||
|
||||
assert.Equal(t, "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", azureEndpointResourceId)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func clearDSProxyCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
83
pkg/tsdb/azuremonitor/azcredentials/builder.go
Normal file
83
pkg/tsdb/azuremonitor/azcredentials/builder.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package azcredentials
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func FromDatasourceData(data map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
|
||||
if credentialsObj, err := getMapOptional(data, "azureCredentials"); err != nil {
|
||||
return nil, err
|
||||
} else if credentialsObj == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return getFromCredentialsObject(credentialsObj, secureData)
|
||||
}
|
||||
}
|
||||
|
||||
func getFromCredentialsObject(credentialsObj map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
|
||||
authType, err := getStringValue(credentialsObj, "authType")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case AzureAuthManagedIdentity:
|
||||
credentials := &AzureManagedIdentityCredentials{}
|
||||
return credentials, nil
|
||||
|
||||
case AzureAuthClientSecret:
|
||||
cloud, err := getStringValue(credentialsObj, "azureCloud")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tenantId, err := getStringValue(credentialsObj, "tenantId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientId, err := getStringValue(credentialsObj, "clientId")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSecret := secureData["azureClientSecret"]
|
||||
|
||||
credentials := &AzureClientSecretCredentials{
|
||||
AzureCloud: cloud,
|
||||
TenantId: tenantId,
|
||||
ClientId: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
}
|
||||
return credentials, nil
|
||||
|
||||
default:
|
||||
err := fmt.Errorf("the authentication type '%s' not supported", authType)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func getMapOptional(obj map[string]interface{}, key string) (map[string]interface{}, error) {
|
||||
if untypedValue, ok := obj[key]; ok {
|
||||
if value, ok := untypedValue.(map[string]interface{}); ok {
|
||||
return value, nil
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be an object", key)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Value optional, not error
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getStringValue(obj map[string]interface{}, key string) (string, error) {
|
||||
if untypedValue, ok := obj[key]; ok {
|
||||
if value, ok := untypedValue.(string); ok {
|
||||
return value, nil
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be a string", key)
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
err := fmt.Errorf("the field '%s' should be set", key)
|
||||
return "", err
|
||||
}
|
||||
}
|
@@ -11,13 +11,17 @@ const authenticationMiddlewareName = "AzureAuthentication"
|
||||
|
||||
func AuthMiddleware(tokenProvider AzureTokenProvider, scopes []string) httpclient.Middleware {
|
||||
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
return ApplyAuth(tokenProvider, scopes, next)
|
||||
})
|
||||
}
|
||||
|
||||
func ApplyAuth(tokenProvider AzureTokenProvider, scopes []string, next http.RoundTripper) http.RoundTripper {
|
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user