PluginProxy: Split implementations of token providers (#32820)

* Split implementations of token providers

* Fix imports

* Fix code racing in unit tests
This commit is contained in:
Sergey Kostrukov 2021-05-03 02:46:32 -10:00 committed by GitHub
parent c1034f3118
commit 19f520d891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 215 additions and 125 deletions

View File

@ -10,7 +10,6 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/util"
"golang.org/x/oauth2/google"
)
// ApplyRoute should use the plugin route data to set auth headers and custom headers.
@ -54,38 +53,37 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
logger.Error("Failed to set plugin route body content", "error", err)
}
tokenProvider := newAccessTokenProvider(ds, route)
if route.TokenAuth != nil {
if token, err := tokenProvider.getAccessToken(data); err != nil {
if tokenProvider := getTokenProvider(ctx, ds, route, data); tokenProvider != nil {
if token, err := tokenProvider.getAccessToken(); err != nil {
logger.Error("Failed to get access token", "error", err)
} else {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}
}
authenticationType := ds.JsonData.Get("authenticationType").MustString("jwt")
if route.JwtTokenAuth != nil && authenticationType == "jwt" {
if token, err := tokenProvider.getJwtAccessToken(ctx, data); err != nil {
logger.Error("Failed to get access token", "error", err)
} else {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}
}
if authenticationType == "gce" {
tokenSrc, err := google.DefaultTokenSource(ctx, route.JwtTokenAuth.Scopes...)
if err != nil {
logger.Error("Failed to get default token from meta data server", "error", err)
} else {
token, err := tokenSrc.Token()
if err != nil {
logger.Error("Failed to get default access token from meta data server", "error", err)
} else {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
}
}
}
logger.Info("Requesting", "url", req.URL.String())
}
func getTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
data templateData) accessTokenProvider {
authenticationType := ds.JsonData.Get("authenticationType").MustString()
switch authenticationType {
case "gce":
return newGceAccessTokenProvider(ctx, ds, pluginRoute)
case "jwt":
if pluginRoute.JwtTokenAuth != nil {
return newJwtAccessTokenProvider(ctx, ds, pluginRoute, data)
}
default:
// Fallback to authentication options when authentication type isn't explicitly configured
if pluginRoute.TokenAuth != nil {
return newGenericAccessTokenProvider(ds, pluginRoute, data)
}
if pluginRoute.JwtTokenAuth != nil {
return newJwtAccessTokenProvider(ctx, ds, pluginRoute, data)
}
}
return nil
}

View File

@ -249,7 +249,10 @@ func TestDataSourceProxy_routeRule(t *testing.T) {
json, err := ioutil.ReadFile("./test-data/access-token-1.json")
require.NoError(t, err)
originalClient := client
client = newFakeHTTPClient(t, json)
defer func() { client = originalClient }()
proxy, err := NewDataSourceProxy(ds, plugin, ctx, "pathwithtoken1", &setting.Cfg{})
require.NoError(t, err)
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, plugin.Routes[0], proxy.ds)

View File

@ -0,0 +1,12 @@
package pluginproxy
import "time"
type accessTokenProvider interface {
getAccessToken() (string, error)
}
var (
// timeNow makes it possible to test usage of time
timeNow = time.Now
)

View File

@ -0,0 +1,41 @@
package pluginproxy
import (
"context"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"golang.org/x/oauth2/google"
)
type gceAccessTokenProvider struct {
datasourceId int64
datasourceVersion int
ctx context.Context
route *plugins.AppPluginRoute
}
func newGceAccessTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute) *gceAccessTokenProvider {
return &gceAccessTokenProvider{
datasourceId: ds.Id,
datasourceVersion: ds.Version,
ctx: ctx,
route: pluginRoute,
}
}
func (provider *gceAccessTokenProvider) getAccessToken() (string, error) {
tokenSrc, err := google.DefaultTokenSource(provider.ctx, provider.route.JwtTokenAuth.Scopes...)
if err != nil {
logger.Error("Failed to get default token from meta data server", "error", err)
return "", err
} else {
token, err := tokenSrc.Token()
if err != nil {
logger.Error("Failed to get default access token from meta data server", "error", err)
return "", err
} else {
return token.AccessToken, nil
}
}
}

View File

@ -2,7 +2,6 @@ package pluginproxy
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
@ -11,22 +10,14 @@ import (
"sync"
"time"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"golang.org/x/oauth2/jwt"
)
var (
tokenCache = tokenCacheType{
cache: map[string]*jwtToken{},
}
oauthJwtTokenCache = oauthJwtTokenCacheType{
cache: map[string]*oauth2.Token{},
}
// timeNow makes it possible to test usage of time
timeNow = time.Now
)
type tokenCacheType struct {
@ -34,15 +25,11 @@ type tokenCacheType struct {
sync.Mutex
}
type oauthJwtTokenCacheType struct {
cache map[string]*oauth2.Token
sync.Mutex
}
type accessTokenProvider struct {
route *plugins.AppPluginRoute
type genericAccessTokenProvider struct {
datasourceId int64
datasourceVersion int
route *plugins.AppPluginRoute
data templateData
}
type jwtToken struct {
@ -81,15 +68,17 @@ func (token *jwtToken) UnmarshalJSON(b []byte) error {
return nil
}
func newAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.AppPluginRoute) *accessTokenProvider {
return &accessTokenProvider{
func newGenericAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
data templateData) *genericAccessTokenProvider {
return &genericAccessTokenProvider{
datasourceId: ds.Id,
datasourceVersion: ds.Version,
route: pluginRoute,
data: data,
}
}
func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) {
func (provider *genericAccessTokenProvider) getAccessToken() (string, error) {
tokenCache.Lock()
defer tokenCache.Unlock()
if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
@ -99,14 +88,14 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
}
}
urlInterpolated, err := interpolateString(provider.route.TokenAuth.Url, data)
urlInterpolated, err := interpolateString(provider.route.TokenAuth.Url, provider.data)
if err != nil {
return "", err
}
params := make(url.Values)
for key, value := range provider.route.TokenAuth.Params {
interpolatedParam, err := interpolateString(value, data)
interpolatedParam, err := interpolateString(value, provider.data)
if err != nil {
return "", err
}
@ -141,68 +130,6 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
return token.AccessToken, nil
}
func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) {
oauthJwtTokenCache.Lock()
defer oauthJwtTokenCache.Unlock()
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
if cachedToken.Expiry.After(timeNow().Add(time.Second * 10)) {
logger.Debug("Using token from cache")
return cachedToken.AccessToken, nil
}
}
conf := &jwt.Config{}
if val, ok := provider.route.JwtTokenAuth.Params["client_email"]; ok {
interpolatedVal, err := interpolateString(val, data)
if err != nil {
return "", err
}
conf.Email = interpolatedVal
}
if val, ok := provider.route.JwtTokenAuth.Params["private_key"]; ok {
interpolatedVal, err := interpolateString(val, data)
if err != nil {
return "", err
}
conf.PrivateKey = []byte(interpolatedVal)
}
if val, ok := provider.route.JwtTokenAuth.Params["token_uri"]; ok {
interpolatedVal, err := interpolateString(val, data)
if err != nil {
return "", err
}
conf.TokenURL = interpolatedVal
}
conf.Scopes = provider.route.JwtTokenAuth.Scopes
token, err := getTokenSource(conf, ctx)
if err != nil {
return "", err
}
oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
logger.Info("Got new access token", "ExpiresOn", token.Expiry)
return token.AccessToken, nil
}
// getTokenSource gets a token source.
// Stubbable by tests.
var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
tokenSrc := conf.TokenSource(ctx)
token, err := tokenSrc.Token()
if err != nil {
return nil, err
}
return token, nil
}
func (provider *accessTokenProvider) getAccessTokenCacheKey() string {
func (provider *genericAccessTokenProvider) getAccessTokenCacheKey() string {
return fmt.Sprintf("%v_%v_%v_%v", provider.datasourceId, provider.datasourceVersion, provider.route.Path, provider.route.Method)
}

View File

@ -0,0 +1,109 @@
package pluginproxy
import (
"context"
"fmt"
"sync"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jwt"
)
var (
oauthJwtTokenCache = oauthJwtTokenCacheType{
cache: map[string]*oauth2.Token{},
}
)
type oauthJwtTokenCacheType struct {
cache map[string]*oauth2.Token
sync.Mutex
}
type jwtAccessTokenProvider struct {
datasourceId int64
datasourceVersion int
ctx context.Context
route *plugins.AppPluginRoute
data templateData
}
func newJwtAccessTokenProvider(ctx context.Context, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
data templateData) *jwtAccessTokenProvider {
return &jwtAccessTokenProvider{
datasourceId: ds.Id,
datasourceVersion: ds.Version,
ctx: ctx,
route: pluginRoute,
data: data,
}
}
func (provider *jwtAccessTokenProvider) getAccessToken() (string, error) {
oauthJwtTokenCache.Lock()
defer oauthJwtTokenCache.Unlock()
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
if cachedToken.Expiry.After(timeNow().Add(time.Second * 10)) {
logger.Debug("Using token from cache")
return cachedToken.AccessToken, nil
}
}
conf := &jwt.Config{}
if val, ok := provider.route.JwtTokenAuth.Params["client_email"]; ok {
interpolatedVal, err := interpolateString(val, provider.data)
if err != nil {
return "", err
}
conf.Email = interpolatedVal
}
if val, ok := provider.route.JwtTokenAuth.Params["private_key"]; ok {
interpolatedVal, err := interpolateString(val, provider.data)
if err != nil {
return "", err
}
conf.PrivateKey = []byte(interpolatedVal)
}
if val, ok := provider.route.JwtTokenAuth.Params["token_uri"]; ok {
interpolatedVal, err := interpolateString(val, provider.data)
if err != nil {
return "", err
}
conf.TokenURL = interpolatedVal
}
conf.Scopes = provider.route.JwtTokenAuth.Scopes
token, err := getTokenSource(conf, provider.ctx)
if err != nil {
return "", err
}
oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
logger.Info("Got new access token", "ExpiresOn", token.Expiry)
return token.AccessToken, nil
}
// getTokenSource gets a token source.
// Stubbable by tests.
var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
tokenSrc := conf.TokenSource(ctx)
token, err := tokenSrc.Token()
if err != nil {
return nil, err
}
return token, nil
}
func (provider *jwtAccessTokenProvider) getAccessTokenCacheKey() string {
return fmt.Sprintf("%v_%v_%v_%v", provider.datasourceId, provider.datasourceVersion, provider.route.Path, provider.route.Method)
}

View File

@ -66,8 +66,8 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
setUp(t, func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
return &oauth2.Token{AccessToken: "abc"}, nil
})
provider := newAccessTokenProvider(ds, pluginRoute)
token, err := provider.getJwtAccessToken(context.Background(), templateData)
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
token, err := provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token)
@ -85,8 +85,8 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
return &oauth2.Token{AccessToken: "abc"}, nil
})
provider := newAccessTokenProvider(ds, pluginRoute)
_, err := provider.getJwtAccessToken(context.Background(), templateData)
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
_, err := provider.getAccessToken()
require.NoError(t, err)
})
@ -96,15 +96,15 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
AccessToken: "abc",
Expiry: time.Now().Add(1 * time.Minute)}, nil
})
provider := newAccessTokenProvider(ds, pluginRoute)
token1, err := provider.getJwtAccessToken(context.Background(), templateData)
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, templateData)
token1, err := provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token1)
getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
return &oauth2.Token{AccessToken: "error: cache not used"}, nil
}
token2, err := provider.getJwtAccessToken(context.Background(), templateData)
token2, err := provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token2)
})
@ -162,7 +162,7 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
mockTimeNow(time.Now())
defer resetTimeNow()
provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute)
provider := newGenericAccessTokenProvider(&models.DataSource{}, pluginRoute, templateData)
testCases := []tokenTestDescription{
{
@ -216,12 +216,12 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
token["expires_on"] = testCase.expiresOn
}
accessToken, err := provider.getAccessToken(templateData)
accessToken, err := provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
// getAccessToken should use internal cache
accessToken, err = provider.getAccessToken(templateData)
accessToken, err = provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
assert.Equal(t, 1, authCalls)
@ -244,20 +244,20 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
mockTimeNow(time.Now())
defer resetTimeNow()
provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute)
provider := newGenericAccessTokenProvider(&models.DataSource{}, pluginRoute, templateData)
token = map[string]interface{}{
"access_token": "2YotnFZFEjr1zCsicMWpAA",
"token_type": "3600",
"refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",
}
accessToken, err := provider.getAccessToken(templateData)
accessToken, err := provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
mockTimeNow(timeNow().Add(3601 * time.Second))
accessToken, err = provider.getAccessToken(templateData)
accessToken, err = provider.getAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
assert.Equal(t, 2, authCalls)