diff --git a/pkg/api/pluginproxy/access_token_provider.go b/pkg/api/pluginproxy/access_token_provider.go index fb07bea0020..d6a687dc8fd 100644 --- a/pkg/api/pluginproxy/access_token_provider.go +++ b/pkg/api/pluginproxy/access_token_provider.go @@ -25,6 +25,8 @@ var ( oauthJwtTokenCache = oauthJwtTokenCacheType{ cache: map[string]*oauth2.Token{}, } + // timeNow makes it possible to test usage of time + timeNow = time.Now ) type tokenCacheType struct { @@ -44,9 +46,39 @@ type accessTokenProvider struct { } type jwtToken struct { - ExpiresOn time.Time `json:"-"` - ExpiresOnString string `json:"expires_on"` - AccessToken string `json:"access_token"` + ExpiresOn time.Time + AccessToken string +} + +func (token *jwtToken) UnmarshalJSON(b []byte) error { + var t struct { + AccessToken string `json:"access_token"` + ExpiresOn *json.Number `json:"expires_on"` + ExpiresIn *json.Number `json:"expires_in"` + } + + if err := json.Unmarshal(b, &t); err != nil { + return err + } + + token.AccessToken = t.AccessToken + token.ExpiresOn = timeNow() + + if t.ExpiresOn != nil { + expiresOn, err := t.ExpiresOn.Int64() + if err != nil { + return err + } + token.ExpiresOn = time.Unix(expiresOn, 0) + } else if t.ExpiresIn != nil { + expiresIn, err := t.ExpiresIn.Int64() + if err != nil { + return err + } + token.ExpiresOn = timeNow().Add(time.Duration(expiresIn) * time.Second) + } + + return nil } func newAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.AppPluginRoute) *accessTokenProvider { @@ -61,7 +93,7 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, tokenCache.Lock() defer tokenCache.Unlock() if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found { - if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) { + if cachedToken.ExpiresOn.After(timeNow().Add(time.Second * 10)) { logger.Info("Using token from cache") return cachedToken.AccessToken, nil } @@ -97,12 +129,8 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, return "", err } - expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64) - token.ExpiresOn = time.Unix(expiresOnEpoch, 0) tokenCache.cache[provider.getAccessTokenCacheKey()] = &token - logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn) - return token.AccessToken, nil } @@ -110,7 +138,7 @@ func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data oauthJwtTokenCache.Lock() defer oauthJwtTokenCache.Unlock() if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found { - if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) { + if cachedToken.Expiry.After(timeNow().Add(time.Second * 10)) { logger.Debug("Using token from cache") return cachedToken.AccessToken, nil } diff --git a/pkg/api/pluginproxy/access_token_provider_test.go b/pkg/api/pluginproxy/access_token_provider_test.go index e75748e4660..efbdb9225f6 100644 --- a/pkg/api/pluginproxy/access_token_provider_test.go +++ b/pkg/api/pluginproxy/access_token_provider_test.go @@ -2,6 +2,11 @@ package pluginproxy import ( "context" + "encoding/json" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "strconv" "testing" "time" @@ -12,6 +17,10 @@ import ( "golang.org/x/oauth2/jwt" ) +var ( + token map[string]interface{} +) + func TestAccessToken(t *testing.T) { Convey("Plugin with JWT token auth route", t, func() { pluginRoute := &plugins.AppPluginRoute{ @@ -91,4 +100,175 @@ func TestAccessToken(t *testing.T) { So(token2, ShouldEqual, "abc") }) }) + + Convey("Plugin with token auth route", t, func() { + apiHandler := http.NewServeMux() + server := httptest.NewServer(apiHandler) + defer server.Close() + + pluginRoute := &plugins.AppPluginRoute{ + Path: "pathwithtokenauth1", + Url: "", + Method: "GET", + TokenAuth: &plugins.JwtTokenAuth{ + Url: server.URL + "/oauth/token", + Scopes: []string{ + "https://www.testapi.com/auth/monitoring.read", + "https://www.testapi.com/auth/cloudplatformprojects.readonly", + }, + Params: map[string]string{ + "grant_type": "client_credentials", + "client_id": "{{.JsonData.client_id}}", + "client_secret": "{{.SecureJsonData.client_secret}}", + "audience": "{{.JsonData.audience}}", + "client_name": "datasource_plugin", + }, + }, + } + + templateData := templateData{ + JsonData: map[string]interface{}{ + "client_id": "my_client_id", + "audience": "www.example.com", + }, + SecureJsonData: map[string]string{ + "client_secret": "my_secret", + }, + } + + var authCalls int + apiHandler.HandleFunc("/oauth/token", func(w http.ResponseWriter, req *http.Request) { + err := json.NewEncoder(w).Encode(token) + require.NoError(t, err) + authCalls++ + }) + + Convey("Should parse token, with different fields and types", func() { + type tokenTestDescription struct { + desc string + expiresIn interface{} + expiresOn interface{} + expectedExpiresOn int64 + } + + mockTimeNow(time.Now()) + defer resetTimeNow() + provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute) + + testCases := []tokenTestDescription{ + { + desc: "token with expires_in in string format", + expiresIn: "3600", + expiresOn: nil, + expectedExpiresOn: timeNow().Unix() + 3600, + }, + { + desc: "token with expires_in in int format", + expiresIn: 3600, + expiresOn: nil, + expectedExpiresOn: timeNow().Unix() + 3600, + }, + { + desc: "token with expires_on in string format", + expiresOn: strconv.FormatInt(timeNow().Add(86*time.Minute).Unix(), 10), + expiresIn: nil, + expectedExpiresOn: timeNow().Add(86 * time.Minute).Unix(), + }, + { + desc: "token with expires_on in int format", + expiresOn: timeNow().Add(86 * time.Minute).Unix(), + expiresIn: nil, + expectedExpiresOn: timeNow().Add(86 * time.Minute).Unix(), + }, + { + desc: "token with both expires_on and expires_in, should prioritize expiresOn", + expiresIn: 5200, + expiresOn: timeNow().Add(1 * time.Hour).Unix(), + expectedExpiresOn: timeNow().Add(1 * time.Hour).Unix(), + }, + } + for _, testCase := range testCases { + Convey(testCase.desc, func() { + clearTokenCache() + // reset the httphandler counter + authCalls = 0 + + token = map[string]interface{}{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + } + + if testCase.expiresIn != nil { + token["expires_in"] = testCase.expiresIn + } + + if testCase.expiresOn != nil { + token["expires_on"] = testCase.expiresOn + } + + accessToken, err := provider.getAccessToken(templateData) + So(err, ShouldBeNil) + So(accessToken, ShouldEqual, token["access_token"]) + + // getAccessToken should use internal cache + accessToken, err = provider.getAccessToken(templateData) + So(err, ShouldBeNil) + So(accessToken, ShouldEqual, token["access_token"]) + So(authCalls, ShouldEqual, 1) + + tokenCache.Lock() + v, ok := tokenCache.cache[provider.getAccessTokenCacheKey()] + tokenCache.Unlock() + + So(ok, ShouldBeTrue) + So(v.ExpiresOn.Unix(), ShouldEqual, testCase.expectedExpiresOn) + So(v.AccessToken, ShouldEqual, token["access_token"]) + }) + } + }) + + Convey("Should refetch token on expire", func() { + clearTokenCache() + // reset the httphandler counter + authCalls = 0 + + mockTimeNow(time.Now()) + defer resetTimeNow() + provider := newAccessTokenProvider(&models.DataSource{}, pluginRoute) + + token = map[string]interface{}{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "3600", + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + } + accessToken, err := provider.getAccessToken(templateData) + So(err, ShouldBeNil) + So(accessToken, ShouldEqual, token["access_token"]) + + mockTimeNow(timeNow().Add(3601 * time.Second)) + + accessToken, err = provider.getAccessToken(templateData) + So(err, ShouldBeNil) + So(accessToken, ShouldEqual, token["access_token"]) + So(authCalls, ShouldEqual, 2) + }) + }) +} + +func clearTokenCache() { + tokenCache.Lock() + defer tokenCache.Unlock() + tokenCache.cache = map[string]*jwtToken{} + token = map[string]interface{}{} +} + +func mockTimeNow(timeSeed time.Time) { + timeNow = func() time.Time { + return timeSeed + } +} + +func resetTimeNow() { + timeNow = time.Now }