Fix: Refresh token when id_token is expired (#79569)

* Fix: Refresh token when id_token is expired

* add id_token comparison

* Fix wire

* Use userID as cache key

* Apply suggestions from code review

---------

Co-authored-by: linoman <2051016+linoman@users.noreply.github.com>
Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
Gabriel MABILLE 2024-02-05 16:44:25 +01:00 committed by GitHub
parent 62806e8f8c
commit 596e828150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 834 additions and 448 deletions

View File

@ -3,26 +3,20 @@ package sync
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
)
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
return &OAuthTokenSync{
log.New("oauth_token.sync"),
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
service,
sessionService,
socialService,
@ -31,12 +25,11 @@ func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService
}
type OAuthTokenSync struct {
log log.Logger
cache *localcache.CacheService
service oauthtoken.OAuthTokenService
sessionService auth.UserTokenService
socialService social.Service
sf *singleflight.Group
log log.Logger
service oauthtoken.OAuthTokenService
sessionService auth.UserTokenService
socialService social.Service
singleflightGroup *singleflight.Group
}
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
@ -51,71 +44,14 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
// if we recently have performed this it would be cached, so we can skip the hook
if _, ok := s.cache.Get(identity.ID); ok {
s.log.FromContext(ctx).Debug("OAuth token check is cached", "id", identity.ID)
return nil
}
token, exists, err := s.service.HasOAuthEntry(ctx, identity)
// user is not authenticated through oauth so skip further checks
if !exists {
if err != nil {
s.log.FromContext(ctx).Error("Failed to fetch oauth entry", "id", identity.ID, "error", err)
}
return nil
}
idTokenExpiry, err := getIDTokenExpiry(token)
if err != nil {
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
}
// token has no expire time configured, so we don't have to refresh it
if token.OAuthExpiry.IsZero() {
s.log.FromContext(ctx).Debug("Access token without expiry", "id", identity.ID)
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
return nil
}
// get the token's auth provider (f.e. azuread)
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
s.log.Warn("OAuth provider not found", "provider", provider)
return nil
}
// if refresh token handling is disabled for this provider, we can skip the hook
if !currentOAuthInfo.UseRefreshToken {
return nil
}
accessTokenExpires, hasAccessTokenExpired := getExpiryWithSkew(token.OAuthExpiry)
hasIdTokenExpired := false
idTokenExpires := time.Time{}
if !idTokenExpiry.IsZero() {
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExpiry)
}
// token has not expired, so we don't have to refresh it
if !hasAccessTokenExpired && !hasIdTokenExpired {
s.log.FromContext(ctx).Debug("Access and id token has not expired yet", "id", identity.ID)
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
return nil
}
_, err, _ = s.sf.Do(identity.ID, func() (interface{}, error) {
_, err, _ := s.singleflightGroup.Do(identity.ID, func() (interface{}, error) {
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if refreshErr := s.service.TryTokenRefresh(updateCtx, token); refreshErr != nil {
if refreshErr := s.service.TryTokenRefresh(updateCtx, identity); refreshErr != nil {
if errors.Is(refreshErr, context.Canceled) {
return nil, nil
}
@ -153,56 +89,3 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
const maxOAuthTokenCacheTTL = 10 * time.Minute
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return maxOAuthTokenCacheTTL
}
min := func(a, b time.Duration) time.Duration {
if a <= b {
return a
}
return b
}
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
}
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
}
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
if token.OAuthIdToken == "" {
return time.Time{}, nil
}
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
if err != nil {
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
}
type Claims struct {
Exp int64 `json:"exp"`
}
var claims Claims
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
}
return time.Unix(claims.Exp, 0), nil
}
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
adjustedExpiry = expiry.Round(0).Add(-oauthtoken.ExpiryDelta)
hasTokenExpired = adjustedExpiry.Before(time.Now())
return
}

View File

@ -2,18 +2,13 @@ package sync
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
@ -45,45 +40,17 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
tests := []testCase{
{
desc: "should skip sync when identity is not a user",
identity: &authn.Identity{ID: "service-account:1"},
desc: "should skip sync when identity is not a user",
identity: &authn.Identity{ID: "service-account:1"},
expectTryRefreshTokenCalled: false,
},
{
desc: "should skip sync when identity is a user but is not authenticated with session token",
identity: &authn.Identity{ID: "user:1"},
desc: "should skip sync when identity is a user but is not authenticated with session token",
identity: &authn.Identity{ID: "user:1"},
expectTryRefreshTokenCalled: false,
},
{
desc: "should skip sync when user has session but is not authenticated with oauth",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
},
{
desc: "should skip sync for when access token don't have expire time",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &login.UserAuth{},
},
{
desc: "should skip sync when access token has no expired yet",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
{
desc: "should skip sync when access token has no expired yet",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
{
desc: "should refresh access token when it has expired",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
},
{
desc: "should invalidate access token and session token if access token can't be refreshed",
desc: "should invalidate access token and session token if token refresh fails",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedTryRefreshErr: errors.New("some err"),
@ -92,21 +59,27 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectRevokeTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: authn.ErrExpiredAccessToken,
}, {
desc: "should skip sync when use_refresh_token is disabled",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: false,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
},
{
desc: "should refresh access token when ID token has expired",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
desc: "should refresh the token successfully",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: false,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,
expectRevokeTokenCalled: false,
},
{
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,
expectRevokeTokenCalled: false,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
expectedTryRefreshErr: errors.New("some err"),
},
// TODO: address coverage of oauthtoken sync
}
for _, tt := range tests {
@ -127,7 +100,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
invalidateTokensCalled = true
return nil
},
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error {
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) error {
tryRefreshCalled = true
return tt.expectedTryRefreshErr
},
@ -151,12 +124,11 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
}
sync := &OAuthTokenSync{
log: log.NewNopLogger(),
cache: localcache.New(0, 0),
service: service,
sessionService: sessionService,
socialService: socialService,
sf: new(singleflight.Group),
log: log.NewNopLogger(),
service: service,
sessionService: sessionService,
socialService: socialService,
singleflightGroup: new(singleflight.Group),
}
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
@ -168,93 +140,3 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
})
}
}
// fakeIDToken is used to create a fake invalid token to verify expiry logic
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
type Header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
type Payload struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Exp int64 `json:"exp"`
}
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
require.NoError(t, err)
u := expiryDate.UTC().Unix()
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
require.NoError(t, err)
fakeSignature := []byte("6ICJm")
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
defaultTime := time.Now()
tests := []struct {
name string
accessTokenExpiry time.Time
idTokenExpiry time.Time
want time.Duration
}{
{
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
accessTokenExpiry: time.Time{},
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
{
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
})
}
}

View File

@ -7,10 +7,12 @@ import (
"strings"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth/identity"
@ -29,28 +31,33 @@ var (
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
)
const maxOAuthTokenCacheTTL = 10 * time.Minute
type Service struct {
Cfg *setting.Cfg
SocialService social.Service
AuthInfoService login.AuthInfoService
singleFlightGroup *singleflight.Group
cache *localcache.CacheService
tokenRefreshDuration *prometheus.HistogramVec
}
//go:generate mockery --name OAuthTokenService --structname MockService --outpkg oauthtokentest --filename service_mock.go --output ./oauthtokentest/
type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
TryTokenRefresh(context.Context, *login.UserAuth) error
TryTokenRefresh(context.Context, identity.Requester) error
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
}
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
return &Service{
AuthInfoService: authInfoService,
Cfg: cfg,
SocialService: socialService,
AuthInfoService: authInfoService,
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
singleFlightGroup: new(singleflight.Group),
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
}
@ -58,36 +65,12 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
if usr == nil || usr.IsNil() {
// No user, therefore no token
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
if !ok {
return nil
}
namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
return nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
return nil
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("No oauth token for user found", "userId", userID, "username", usr.GetLogin())
} else {
logger.Error("Failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
}
return nil
}
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfo)
token, err := o.tryGetOrRefreshOAuthToken(ctx, authInfo)
if err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfo)
@ -119,6 +102,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
return nil, false, err
}
@ -127,6 +111,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("No oauth token found for user", "userId", userID, "username", usr.GetLogin())
return nil, false, nil
}
logger.Error("Failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
@ -140,13 +125,72 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
func (o *Service) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (any, error) {
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
if usr == nil || usr.IsNil() {
logger.Warn("Can only refresh OAuth tokens for existing users", "user", "nil")
// Not user, no token.
return nil
}
namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
logger.Warn("Can only refresh OAuth tokens for users", "namespace", namespace, "userId", id)
return nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
logger.Warn("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
return nil
}
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
if _, ok := o.cache.Get(lockKey); ok {
logger.Debug("Expiration check has been cached, no need to refresh", "userID", userID)
return nil
}
_, err, _ = o.singleFlightGroup.Do(lockKey, func() (any, error) {
logger.Debug("Singleflight request for getting a new access token", "key", lockKey)
return o.tryGetOrRefreshAccessToken(ctx, usr)
authInfo, exists, err := o.HasOAuthEntry(ctx, usr)
if !exists {
if err != nil {
logger.Debug("Failed to fetch oauth entry", "id", userID, "error", err)
} else {
// User is not logged in via OAuth no need to check
o.cache.Set(lockKey, struct{}{}, maxOAuthTokenCacheTTL)
}
return nil, nil
}
_, needRefresh, ttl := needTokenRefresh(authInfo)
if !needRefresh {
o.cache.Set(lockKey, struct{}{}, ttl)
return nil, nil
}
// get the token's auth provider (f.e. azuread)
provider := strings.TrimPrefix(authInfo.AuthModule, "oauth_")
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
logger.Warn("OAuth provider not found", "provider", provider)
return nil, nil
}
// if refresh token handling is disabled for this provider, we can skip the refresh
if !currentOAuthInfo.UseRefreshToken {
logger.Debug("Skipping token refresh", "provider", provider)
return nil, nil
}
return o.tryGetOrRefreshOAuthToken(ctx, authInfo)
})
// Silence ErrNoRefreshTokenFound
if errors.Is(err, ErrNoRefreshTokenFound) {
return nil
}
return err
}
@ -195,11 +239,23 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth
})
}
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
key := getCheckCacheKey(usr.UserId)
if _, ok := o.cache.Get(key); ok {
logger.Debug("Expiration check has been cached", "userID", usr.UserId)
return buildOAuthTokenFromAuthInfo(usr), nil
}
if err := checkOAuthRefreshToken(usr); err != nil {
return nil, err
}
persistedToken, refreshNeeded, ttl := needTokenRefresh(usr)
if !refreshNeeded {
o.cache.Set(key, struct{}{}, ttl)
return persistedToken, nil
}
authProvider := usr.AuthModule
connect, err := o.SocialService.GetConnector(authProvider)
if err != nil {
@ -214,8 +270,6 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
persistedToken := buildOAuthTokenFromAuthInfo(usr)
start := time.Now()
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
@ -278,8 +332,91 @@ func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
func tokensEq(t1, t2 *oauth2.Token) bool {
t1IdToken, ok1 := t1.Extra("id_token").(string)
t2IdToken, ok2 := t2.Extra("id_token").(string)
return t1.AccessToken == t2.AccessToken &&
t1.RefreshToken == t2.RefreshToken &&
t1.Expiry.Equal(t2.Expiry) &&
t1.TokenType == t2.TokenType
t1.TokenType == t2.TokenType &&
ok1 == ok2 &&
t1IdToken == t2IdToken
}
func needTokenRefresh(usr *login.UserAuth) (*oauth2.Token, bool, time.Duration) {
var accessTokenExpires, idTokenExpires time.Time
var hasAccessTokenExpired, hasIdTokenExpired bool
persistedToken := buildOAuthTokenFromAuthInfo(usr)
idTokenExp, err := getIDTokenExpiry(usr)
if err != nil {
logger.Warn("Could not get ID Token expiry", "error", err)
}
if !persistedToken.Expiry.IsZero() {
accessTokenExpires, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
}
if !idTokenExp.IsZero() {
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
}
if !hasAccessTokenExpired && !hasIdTokenExpired {
logger.Debug("Neither access nor id token have expired yet", "id", usr.Id)
return persistedToken, false, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)
}
if hasIdTokenExpired {
// Force refreshing token when id token is expired
persistedToken.AccessToken = ""
}
return persistedToken, true, time.Second
}
func getCheckCacheKey(usrID int64) string {
return fmt.Sprintf("token-check-%d", usrID)
}
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
min := maxOAuthTokenCacheTTL
if !accessTokenExpiry.IsZero() {
d := time.Until(accessTokenExpiry)
if d < min {
min = d
}
}
if !idTokenExpiry.IsZero() {
d := time.Until(idTokenExpiry)
if d < min {
min = d
}
}
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return maxOAuthTokenCacheTTL
}
return min
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry(usr *login.UserAuth) (time.Time, error) {
if usr.OAuthIdToken == "" {
return time.Time{}, nil
}
parsedToken, err := jwt.ParseSigned(usr.OAuthIdToken)
if err != nil {
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
}
type Claims struct {
Exp int64 `json:"exp"`
}
var claims Claims
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
}
return time.Unix(claims.Exp, 0), nil
}
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
adjustedExpiry = expiry.Round(0).Add(-ExpiryDelta)
hasTokenExpired = adjustedExpiry.Before(time.Now())
return
}

View File

@ -10,20 +10,26 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
func TestService_HasOAuthEntry(t *testing.T) {
testCases := []struct {
name string
@ -69,10 +75,10 @@ func TestService_HasOAuthEntry(t *testing.T) {
{
name: "returns true when the auth entry is found",
user: &user.SignedInUser{UserID: 1},
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
want: &login.UserAuth{AuthModule: login.GenericOAuthModule},
wantExist: true,
wantErr: false,
getAuthInfoUser: login.UserAuth{AuthModule: "oauth_generic_oauth"},
getAuthInfoUser: login.UserAuth{AuthModule: login.GenericOAuthModule},
},
}
for _, tc := range testCases {
@ -96,152 +102,26 @@ func TestService_HasOAuthEntry(t *testing.T) {
}
}
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now(),
TokenType: "Bearer",
}
usr := &login.UserAuth{
AuthModule: "oauth_generic_oauth",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
authInfoStore.ExpectedOAuth = usr
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
require.Nil(t, err)
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
require.Nil(t, err)
// User's token data had not been updated
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
}
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
srv, _, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
usr := &login.UserAuth{
AuthModule: "oauth_generic_oauth",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
assert.NotNil(t, err)
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
socialConnector.AssertNotCalled(t, "TokenSource")
}
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
newToken := &oauth2.Token{
AccessToken: "testaccess_new",
RefreshToken: "testrefresh_new",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
usr := &login.UserAuth{
AuthModule: "oauth_generic_oauth",
UserId: 1,
AuthId: "test",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
authInfoStore.ExpectedOAuth = usr
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
err := srv.TryTokenRefresh(ctx, usr)
require.Nil(t, err)
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
require.Nil(t, err)
// newToken should be returned after the .Token() call, therefore the User had to be updated
assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken)
assert.Equal(t, authInfo.OAuthExpiry, newToken.Expiry)
assert.Equal(t, authInfo.OAuthRefreshToken, newToken.RefreshToken)
assert.Equal(t, authInfo.OAuthTokenType, newToken.TokenType)
}
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
srv, _, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{}
usr := &login.UserAuth{
AuthModule: "auth.saml",
}
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
assert.NotNil(t, err)
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
socialConnector.AssertNotCalled(t, "TokenSource")
}
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
t.Helper()
socialConnector := &socialtest.MockSocialConnector{}
socialService := &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
ExpectedAuthInfoProvider: &social.OAuthInfo{
UseRefreshToken: true,
},
}
authInfoStore := &FakeAuthInfoStore{}
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(),
secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
authInfoStore := &FakeAuthInfoStore{ExpectedOAuth: &login.UserAuth{}}
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(), secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
return &Service{
Cfg: setting.NewCfg(),
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: &singleflight.Group{},
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
}, authInfoStore, socialConnector
}
@ -270,3 +150,461 @@ func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.Updat
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
return f.ExpectedError
}
func TestService_TryTokenRefresh(t *testing.T) {
type environment struct {
authInfoService *authinfotest.FakeService
cache *localcache.CacheService
identity identity.Requester
socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService
service *Service
}
type testCase struct {
desc string
expectedErr error
setup func(env *environment)
}
tests := []testCase{
{
desc: "should skip sync when identity is nil",
},
{
desc: "should skip sync when identity is not a user",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "service-account:1"}
},
},
{
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:invalidIdentifierFormat"}
},
},
{
desc: "should skip token refresh since the token is still valid",
setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
env.identity = &authn.Identity{
AuthenticatedBy: login.GenericOAuthModule,
ID: "user:1234",
}
},
},
{
desc: "should skip token refresh if the expiration check has already been cached",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
},
},
{
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedError = errors.New("some error")
},
},
{
desc: "should skip token refresh if the user doesn't have an oauth entry",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.SAMLAuthModule,
}
},
},
{
desc: "should do token refresh if access token or id token have not expired yet",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
}
},
},
{
desc: "should skip token refresh when no oauth provider was found",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT,
}
},
},
{
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT,
}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: false,
}
},
},
{
desc: "should skip token refresh when there is no refresh token",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "user:1234"}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT,
OAuthRefreshToken: "",
}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
},
},
{
desc: "should do token refresh when the token is expired",
setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "user:1234"}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
AuthId: "subject",
UserId: 1,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
},
},
{
desc: "should refresh token when the id token is expired",
setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "user:1234"}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
AuthId: "subject",
UserId: 1,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
OAuthIdToken: EXPIRED_JWT,
}
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{}
env := environment{
authInfoService: &authinfotest.FakeService{},
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
socialConnector: socialConnector,
socialService: &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
},
}
if tt.setup != nil {
tt.setup(&env)
}
env.service = &Service{
AuthInfoService: env.authInfoService,
Cfg: setting.NewCfg(),
cache: env.cache,
singleFlightGroup: &singleflight.Group{},
SocialService: env.socialService,
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
}
// token refresh
err := env.service.TryTokenRefresh(context.Background(), env.identity)
// test and validations
assert.ErrorIs(t, err, tt.expectedErr)
socialConnector.AssertExpectations(t)
})
}
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
defaultTime := time.Now()
tests := []struct {
name string
accessTokenExpiry time.Time
idTokenExpiry time.Time
want time.Duration
}{
{
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
accessTokenExpiry: time.Time{},
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
{
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
})
}
}
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
tests := []struct {
name string
usr *login.UserAuth
expectedTokenRefreshFlag bool
expectedTokenDuration time.Duration
}{
{
name: "should not need token refresh when token has no expiration date",
usr: &login.UserAuth{},
expectedTokenRefreshFlag: false,
expectedTokenDuration: maxOAuthTokenCacheTTL,
},
{
name: "should not need token refresh with an invalid jwt token that might result in an error when parsing",
usr: &login.UserAuth{
OAuthIdToken: "invalid_jwt_format",
},
expectedTokenRefreshFlag: false,
expectedTokenDuration: maxOAuthTokenCacheTTL,
},
{
name: "should flag token refresh with id token is expired",
usr: &login.UserAuth{
OAuthIdToken: EXPIRED_JWT,
},
expectedTokenRefreshFlag: true,
expectedTokenDuration: time.Second,
},
{
name: "should flag token refresh when expiry date is zero",
usr: &login.UserAuth{
OAuthExpiry: time.Unix(0, 0),
},
expectedTokenRefreshFlag: true,
expectedTokenDuration: time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, needsTokenRefresh, tokenDuration := needTokenRefresh(tt.usr)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
assert.Equal(t, tt.expectedTokenDuration, tokenDuration)
})
}
}
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
timeNow := time.Now()
token := &oauth2.Token{
AccessToken: "oauth_access_token",
RefreshToken: "refresh_token_found",
Expiry: timeNow,
TokenType: "Bearer",
}
type environment struct {
authInfoService *authinfotest.FakeService
cache *localcache.CacheService
socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService
service *Service
}
tests := []struct {
desc string
expectedErr error
expectedToken *oauth2.Token
usr *login.UserAuth
setup func(env *environment)
}{
{
desc: "should find and retrieve token from cache",
usr: &login.UserAuth{
UserId: int64(1234),
OAuthAccessToken: "new_access_token",
OAuthExpiry: timeNow,
},
setup: func(env *environment) {
env.cache.Set("token-check-1234", token, 1*time.Minute)
},
expectedToken: &oauth2.Token{
AccessToken: "new_access_token",
Expiry: timeNow,
},
},
{
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.SAMLAuthModule,
},
expectedErr: ErrNotAnOAuthProvider,
},
{
desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
},
expectedErr: ErrNoRefreshTokenFound,
},
{
desc: "should not refresh token if the token is not expired",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: timeNow.Add(time.Hour),
OAuthTokenType: "Bearer",
},
expectedToken: &oauth2.Token{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expiry: timeNow.Add(time.Hour),
TokenType: "Bearer",
},
},
{
desc: "should update saved token if the user auth has new access/refresh tokens",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: "new_oauth_access_token",
OAuthRefreshToken: "new_refresh_token_found",
OAuthExpiry: timeNow,
},
expectedToken: &oauth2.Token{
AccessToken: "oauth_access_token",
RefreshToken: "refresh_token_found",
Expiry: timeNow,
TokenType: "Bearer",
},
setup: func(env *environment) {
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{}
env := environment{
authInfoService: &authinfotest.FakeService{},
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
socialConnector: socialConnector,
socialService: &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
},
}
if tt.setup != nil {
tt.setup(&env)
}
env.service = &Service{
AuthInfoService: env.authInfoService,
Cfg: setting.NewCfg(),
cache: env.cache,
singleFlightGroup: &singleflight.Group{},
SocialService: env.socialService,
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
}
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)
if tt.expectedToken != nil {
assert.Equal(t, tt.expectedToken, token)
}
assert.ErrorIs(t, tt.expectedErr, err)
socialConnector.AssertExpectations(t)
})
}
}

View File

@ -15,7 +15,7 @@ type MockOauthTokenService struct {
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) error
}
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
@ -46,7 +46,7 @@ func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *
return nil
}
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
if m.TryTokenRefreshFunc != nil {
return m.TryTokenRefreshFunc(ctx, usr)
}

View File

@ -33,7 +33,7 @@ func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.Use
return nil, false, nil
}
func (s *Service) TryTokenRefresh(context.Context, *login.UserAuth) error {
func (s *Service) TryTokenRefresh(context.Context, identity.Requester) error {
return nil
}

View File

@ -0,0 +1,146 @@
// Code generated by mockery v2.40.1. DO NOT EDIT.
package oauthtokentest
import (
context "context"
identity "github.com/grafana/grafana/pkg/services/auth/identity"
datasources "github.com/grafana/grafana/pkg/services/datasources"
login "github.com/grafana/grafana/pkg/services/login"
mock "github.com/stretchr/testify/mock"
oauth2 "golang.org/x/oauth2"
)
// MockService is an autogenerated mock type for the OAuthTokenService type
type MockService struct {
mock.Mock
}
// GetCurrentOAuthToken provides a mock function with given fields: _a0, _a1
func (_m *MockService) GetCurrentOAuthToken(_a0 context.Context, _a1 identity.Requester) *oauth2.Token {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for GetCurrentOAuthToken")
}
var r0 *oauth2.Token
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *oauth2.Token); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*oauth2.Token)
}
}
return r0
}
// HasOAuthEntry provides a mock function with given fields: _a0, _a1
func (_m *MockService) HasOAuthEntry(_a0 context.Context, _a1 identity.Requester) (*login.UserAuth, bool, error) {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for HasOAuthEntry")
}
var r0 *login.UserAuth
var r1 bool
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) (*login.UserAuth, bool, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *login.UserAuth); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*login.UserAuth)
}
}
if rf, ok := ret.Get(1).(func(context.Context, identity.Requester) bool); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Get(1).(bool)
}
if rf, ok := ret.Get(2).(func(context.Context, identity.Requester) error); ok {
r2 = rf(_a0, _a1)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// InvalidateOAuthTokens provides a mock function with given fields: _a0, _a1
func (_m *MockService) InvalidateOAuthTokens(_a0 context.Context, _a1 *login.UserAuth) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for InvalidateOAuthTokens")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *login.UserAuth) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsOAuthPassThruEnabled provides a mock function with given fields: _a0
func (_m *MockService) IsOAuthPassThruEnabled(_a0 *datasources.DataSource) bool {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for IsOAuthPassThruEnabled")
}
var r0 bool
if rf, ok := ret.Get(0).(func(*datasources.DataSource) bool); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// TryTokenRefresh provides a mock function with given fields: _a0, _a1
func (_m *MockService) TryTokenRefresh(_a0 context.Context, _a1 identity.Requester) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for TryTokenRefresh")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockService(t interface {
mock.TestingT
Cleanup(func())
}) *MockService {
mock := &MockService{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}