mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
PluginProxy: Split implementations of token providers (#32820)
* Split implementations of token providers * Fix imports * Fix code racing in unit tests
This commit is contained in:
parent
c1034f3118
commit
19f520d891
@ -10,7 +10,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// ApplyRoute should use the plugin route data to set auth headers and custom headers.
|
||||
@ -54,38 +53,37 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
|
||||
logger.Error("Failed to set plugin route body content", "error", err)
|
||||
}
|
||||
|
||||
tokenProvider := newAccessTokenProvider(ds, route)
|
||||
|
||||
if route.TokenAuth != nil {
|
||||
if token, err := tokenProvider.getAccessToken(data); err != nil {
|
||||
if tokenProvider := getTokenProvider(ctx, ds, route, data); tokenProvider != nil {
|
||||
if token, err := tokenProvider.getAccessToken(); err != nil {
|
||||
logger.Error("Failed to get access token", "error", err)
|
||||
} else {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
}
|
||||
}
|
||||
|
||||
authenticationType := ds.JsonData.Get("authenticationType").MustString("jwt")
|
||||
if route.JwtTokenAuth != nil && authenticationType == "jwt" {
|
||||
if token, err := tokenProvider.getJwtAccessToken(ctx, data); err != nil {
|
||||
logger.Error("Failed to get access token", "error", err)
|
||||
} else {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
}
|
||||
}
|
||||
|
||||
if authenticationType == "gce" {
|
||||
tokenSrc, err := google.DefaultTokenSource(ctx, route.JwtTokenAuth.Scopes...)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default token from meta data server", "error", err)
|
||||
} else {
|
||||
token, err := tokenSrc.Token()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default access token from meta data server", "error", err)
|
||||
} else {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Requesting", "url", req.URL.String())
|
||||
}
|
||||
|
||||
func getTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
data templateData) accessTokenProvider {
|
||||
authenticationType := ds.JsonData.Get("authenticationType").MustString()
|
||||
|
||||
switch authenticationType {
|
||||
case "gce":
|
||||
return newGceAccessTokenProvider(ctx, ds, pluginRoute)
|
||||
case "jwt":
|
||||
if pluginRoute.JwtTokenAuth != nil {
|
||||
return newJwtAccessTokenProvider(ctx, ds, pluginRoute, data)
|
||||
}
|
||||
default:
|
||||
// Fallback to authentication options when authentication type isn't explicitly configured
|
||||
if pluginRoute.TokenAuth != nil {
|
||||
return newGenericAccessTokenProvider(ds, pluginRoute, data)
|
||||
}
|
||||
if pluginRoute.JwtTokenAuth != nil {
|
||||
return newJwtAccessTokenProvider(ctx, ds, pluginRoute, data)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -249,7 +249,10 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
|
||||
json, err := ioutil.ReadFile("./test-data/access-token-1.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
originalClient := client
|
||||
client = newFakeHTTPClient(t, json)
|
||||
defer func() { client = originalClient }()
|
||||
|
||||
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", &setting.Cfg{})
|
||||
require.NoError(t, err)
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds)
|
||||
|
12
pkg/api/pluginproxy/token_provider.go
Normal file
12
pkg/api/pluginproxy/token_provider.go
Normal file
@ -0,0 +1,12 @@
|
||||
package pluginproxy
|
||||
|
||||
import "time"
|
||||
|
||||
type accessTokenProvider interface {
|
||||
getAccessToken() (string, error)
|
||||
}
|
||||
|
||||
var (
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
41
pkg/api/pluginproxy/token_provider_gce.go
Normal file
41
pkg/api/pluginproxy/token_provider_gce.go
Normal file
@ -0,0 +1,41 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
type gceAccessTokenProvider struct {
|
||||
datasourceId int64
|
||||
datasourceVersion int
|
||||
ctx context.Context
|
||||
route *plugins.AppPluginRoute
|
||||
}
|
||||
|
||||
func newGceAccessTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute) *gceAccessTokenProvider {
|
||||
return &gceAccessTokenProvider{
|
||||
datasourceId: ds.Id,
|
||||
datasourceVersion: ds.Version,
|
||||
ctx: ctx,
|
||||
route: pluginRoute,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *gceAccessTokenProvider) getAccessToken() (string, error) {
|
||||
tokenSrc, err := google.DefaultTokenSource(provider.ctx, provider.route.JwtTokenAuth.Scopes...)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default token from meta data server", "error", err)
|
||||
return "", err
|
||||
} else {
|
||||
token, err := tokenSrc.Token()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default access token from meta data server", "error", err)
|
||||
return "", err
|
||||
} else {
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
@ -2,7 +2,6 @@ package pluginproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -11,22 +10,14 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"golang.org/x/oauth2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
tokenCache = tokenCacheType{
|
||||
cache: map[string]*jwtToken{},
|
||||
}
|
||||
oauthJwtTokenCache = oauthJwtTokenCacheType{
|
||||
cache: map[string]*oauth2.Token{},
|
||||
}
|
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type tokenCacheType struct {
|
||||
@ -34,15 +25,11 @@ type tokenCacheType struct {
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
type oauthJwtTokenCacheType struct {
|
||||
cache map[string]*oauth2.Token
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
type accessTokenProvider struct {
|
||||
route *plugins.AppPluginRoute
|
||||
type genericAccessTokenProvider struct {
|
||||
datasourceId int64
|
||||
datasourceVersion int
|
||||
route *plugins.AppPluginRoute
|
||||
data templateData
|
||||
}
|
||||
|
||||
type jwtToken struct {
|
||||
@ -81,15 +68,17 @@ func (token *jwtToken) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.AppPluginRoute) *accessTokenProvider {
|
||||
return &accessTokenProvider{
|
||||
func newGenericAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
data templateData) *genericAccessTokenProvider {
|
||||
return &genericAccessTokenProvider{
|
||||
datasourceId: ds.Id,
|
||||
datasourceVersion: ds.Version,
|
||||
route: pluginRoute,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) {
|
||||
func (provider *genericAccessTokenProvider) getAccessToken() (string, error) {
|
||||
tokenCache.Lock()
|
||||
defer tokenCache.Unlock()
|
||||
if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
|
||||
@ -99,14 +88,14 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
|
||||
}
|
||||
}
|
||||
|
||||
urlInterpolated, err := interpolateString(provider.route.TokenAuth.Url, data)
|
||||
urlInterpolated, err := interpolateString(provider.route.TokenAuth.Url, provider.data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
params := make(url.Values)
|
||||
for key, value := range provider.route.TokenAuth.Params {
|
||||
interpolatedParam, err := interpolateString(value, data)
|
||||
interpolatedParam, err := interpolateString(value, provider.data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -141,68 +130,6 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) {
|
||||
oauthJwtTokenCache.Lock()
|
||||
defer oauthJwtTokenCache.Unlock()
|
||||
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
|
||||
if cachedToken.Expiry.After(timeNow().Add(time.Second * 10)) {
|
||||
logger.Debug("Using token from cache")
|
||||
return cachedToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
conf := &jwt.Config{}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["client_email"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.Email = interpolatedVal
|
||||
}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["private_key"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.PrivateKey = []byte(interpolatedVal)
|
||||
}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["token_uri"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.TokenURL = interpolatedVal
|
||||
}
|
||||
|
||||
conf.Scopes = provider.route.JwtTokenAuth.Scopes
|
||||
|
||||
token, err := getTokenSource(conf, ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
|
||||
|
||||
logger.Info("Got new access token", "ExpiresOn", token.Expiry)
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// getTokenSource gets a token source.
|
||||
// Stubbable by tests.
|
||||
var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
|
||||
tokenSrc := conf.TokenSource(ctx)
|
||||
token, err := tokenSrc.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (provider *accessTokenProvider) getAccessTokenCacheKey() string {
|
||||
func (provider *genericAccessTokenProvider) getAccessTokenCacheKey() string {
|
||||
return fmt.Sprintf("%v_%v_%v_%v", provider.datasourceId, provider.datasourceVersion, provider.route.Path, provider.route.Method)
|
||||
}
|
109
pkg/api/pluginproxy/token_provider_jwt.go
Normal file
109
pkg/api/pluginproxy/token_provider_jwt.go
Normal file
@ -0,0 +1,109 @@
|
||||
package pluginproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthJwtTokenCache = oauthJwtTokenCacheType{
|
||||
cache: map[string]*oauth2.Token{},
|
||||
}
|
||||
)
|
||||
|
||||
type oauthJwtTokenCacheType struct {
|
||||
cache map[string]*oauth2.Token
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
type jwtAccessTokenProvider struct {
|
||||
datasourceId int64
|
||||
datasourceVersion int
|
||||
ctx context.Context
|
||||
route *plugins.AppPluginRoute
|
||||
data templateData
|
||||
}
|
||||
|
||||
func newJwtAccessTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
|
||||
data templateData) *jwtAccessTokenProvider {
|
||||
return &jwtAccessTokenProvider{
|
||||
datasourceId: ds.Id,
|
||||
datasourceVersion: ds.Version,
|
||||
ctx: ctx,
|
||||
route: pluginRoute,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *jwtAccessTokenProvider) getAccessToken() (string, error) {
|
||||
oauthJwtTokenCache.Lock()
|
||||
defer oauthJwtTokenCache.Unlock()
|
||||
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
|
||||
if cachedToken.Expiry.After(timeNow().Add(time.Second * 10)) {
|
||||
logger.Debug("Using token from cache")
|
||||
return cachedToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
conf := &jwt.Config{}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["client_email"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, provider.data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.Email = interpolatedVal
|
||||
}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["private_key"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, provider.data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.PrivateKey = []byte(interpolatedVal)
|
||||
}
|
||||
|
||||
if val, ok := provider.route.JwtTokenAuth.Params["token_uri"]; ok {
|
||||
interpolatedVal, err := interpolateString(val, provider.data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
conf.TokenURL = interpolatedVal
|
||||
}
|
||||
|
||||
conf.Scopes = provider.route.JwtTokenAuth.Scopes
|
||||
|
||||
token, err := getTokenSource(conf, provider.ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
|
||||
|
||||
logger.Info("Got new access token", "ExpiresOn", token.Expiry)
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// getTokenSource gets a token source.
|
||||
// Stubbable by tests.
|
||||
var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
|
||||
tokenSrc := conf.TokenSource(ctx)
|
||||
token, err := tokenSrc.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (provider *jwtAccessTokenProvider) getAccessTokenCacheKey() string {
|
||||
return fmt.Sprintf("%v_%v_%v_%v", provider.datasourceId, provider.datasourceVersion, provider.route.Path, provider.route.Method)
|
||||
}
|
@ -66,8 +66,8 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
setUp(t, func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{AccessToken: "abc"}, nil
|
||||
})
|
||||
provider := newAccessTokenProvider(ds, pluginRoute)
|
||||
token, err := provider.getJwtAccessToken(context.Background(), templateData)
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
|
||||
token, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc", token)
|
||||
@ -85,8 +85,8 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
return &oauth2.Token{AccessToken: "abc"}, nil
|
||||
})
|
||||
|
||||
provider := newAccessTokenProvider(ds, pluginRoute)
|
||||
_, err := provider.getJwtAccessToken(context.Background(), templateData)
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
|
||||
_, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@ -96,15 +96,15 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
|
||||
AccessToken: "abc",
|
||||
Expiry: time.Now().Add(1 * time.Minute)}, nil
|
||||
})
|
||||
provider := newAccessTokenProvider(ds, pluginRoute)
|
||||
token1, err := provider.getJwtAccessToken(context.Background(), templateData)
|
||||
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
|
||||
token1, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "abc", token1)
|
||||
|
||||
getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{AccessToken: "error: cache not used"}, nil
|
||||
}
|
||||
token2, err := provider.getJwtAccessToken(context.Background(), templateData)
|
||||
token2, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "abc", token2)
|
||||
})
|
||||
@ -162,7 +162,7 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
|
||||
|
||||
mockTimeNow(time.Now())
|
||||
defer resetTimeNow()
|
||||
provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute)
|
||||
provider := newGenericAccessTokenProvider(&models.DataSource{}, pluginRoute, templateData)
|
||||
|
||||
testCases := []tokenTestDescription{
|
||||
{
|
||||
@ -216,12 +216,12 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
|
||||
token["expires_on"] = testCase.expiresOn
|
||||
}
|
||||
|
||||
accessToken, err := provider.getAccessToken(templateData)
|
||||
accessToken, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
|
||||
// getAccessToken should use internal cache
|
||||
accessToken, err = provider.getAccessToken(templateData)
|
||||
accessToken, err = provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
assert.Equal(t, 1, authCalls)
|
||||
@ -244,20 +244,20 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
|
||||
|
||||
mockTimeNow(time.Now())
|
||||
defer resetTimeNow()
|
||||
provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute)
|
||||
provider := newGenericAccessTokenProvider(&models.DataSource{}, pluginRoute, templateData)
|
||||
|
||||
token = map[string]interface{}{
|
||||
"access_token": "2YotnFZFEjr1zCsicMWpAA",
|
||||
"token_type": "3600",
|
||||
"refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",
|
||||
}
|
||||
accessToken, err := provider.getAccessToken(templateData)
|
||||
accessToken, err := provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
|
||||
mockTimeNow(timeNow().Add(3601 * time.Second))
|
||||
|
||||
accessToken, err = provider.getAccessToken(templateData)
|
||||
accessToken, err = provider.getAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token["access_token"], accessToken)
|
||||
assert.Equal(t, 2, authCalls)
|
Loading…
Reference in New Issue
Block a user