mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Fix: Refresh token when id_token is expired (#79569)
* Fix: Refresh token when id_token is expired * add id_token comparison * Fix wire * Use userID as cache key * Apply suggestions from code review --------- Co-authored-by: linoman <2051016+linoman@users.noreply.github.com> Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
parent
62806e8f8c
commit
596e828150
@ -3,26 +3,20 @@ package sync
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||||
return &OAuthTokenSync{
|
return &OAuthTokenSync{
|
||||||
log.New("oauth_token.sync"),
|
log.New("oauth_token.sync"),
|
||||||
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
|
||||||
service,
|
service,
|
||||||
sessionService,
|
sessionService,
|
||||||
socialService,
|
socialService,
|
||||||
@ -31,12 +25,11 @@ func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthTokenSync struct {
|
type OAuthTokenSync struct {
|
||||||
log log.Logger
|
log log.Logger
|
||||||
cache *localcache.CacheService
|
service oauthtoken.OAuthTokenService
|
||||||
service oauthtoken.OAuthTokenService
|
sessionService auth.UserTokenService
|
||||||
sessionService auth.UserTokenService
|
socialService social.Service
|
||||||
socialService social.Service
|
singleflightGroup *singleflight.Group
|
||||||
sf *singleflight.Group
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||||
@ -51,71 +44,14 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we recently have performed this it would be cached, so we can skip the hook
|
_, err, _ := s.singleflightGroup.Do(identity.ID, func() (interface{}, error) {
|
||||||
if _, ok := s.cache.Get(identity.ID); ok {
|
|
||||||
s.log.FromContext(ctx).Debug("OAuth token check is cached", "id", identity.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
token, exists, err := s.service.HasOAuthEntry(ctx, identity)
|
|
||||||
// user is not authenticated through oauth so skip further checks
|
|
||||||
if !exists {
|
|
||||||
if err != nil {
|
|
||||||
s.log.FromContext(ctx).Error("Failed to fetch oauth entry", "id", identity.ID, "error", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
|
||||||
if err != nil {
|
|
||||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// token has no expire time configured, so we don't have to refresh it
|
|
||||||
if token.OAuthExpiry.IsZero() {
|
|
||||||
s.log.FromContext(ctx).Debug("Access token without expiry", "id", identity.ID)
|
|
||||||
// cache the token check, so we don't perform it on every request
|
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the token's auth provider (f.e. azuread)
|
|
||||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
|
||||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
|
||||||
if currentOAuthInfo == nil {
|
|
||||||
s.log.Warn("OAuth provider not found", "provider", provider)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
|
||||||
if !currentOAuthInfo.UseRefreshToken {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
accessTokenExpires, hasAccessTokenExpired := getExpiryWithSkew(token.OAuthExpiry)
|
|
||||||
|
|
||||||
hasIdTokenExpired := false
|
|
||||||
idTokenExpires := time.Time{}
|
|
||||||
|
|
||||||
if !idTokenExpiry.IsZero() {
|
|
||||||
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExpiry)
|
|
||||||
}
|
|
||||||
// token has not expired, so we don't have to refresh it
|
|
||||||
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
|
||||||
s.log.FromContext(ctx).Debug("Access and id token has not expired yet", "id", identity.ID)
|
|
||||||
// cache the token check, so we don't perform it on every request
|
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err, _ = s.sf.Do(identity.ID, func() (interface{}, error) {
|
|
||||||
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
|
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
|
||||||
|
|
||||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if refreshErr := s.service.TryTokenRefresh(updateCtx, token); refreshErr != nil {
|
if refreshErr := s.service.TryTokenRefresh(updateCtx, identity); refreshErr != nil {
|
||||||
if errors.Is(refreshErr, context.Canceled) {
|
if errors.Is(refreshErr, context.Canceled) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -153,56 +89,3 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
|
||||||
|
|
||||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
|
||||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
|
||||||
return maxOAuthTokenCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
min := func(a, b time.Duration) time.Duration {
|
|
||||||
if a <= b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
|
||||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
|
||||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
|
||||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
|
||||||
if token.OAuthIdToken == "" {
|
|
||||||
return time.Time{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
|
||||||
if err != nil {
|
|
||||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Claims struct {
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
var claims Claims
|
|
||||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
|
||||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Unix(claims.Exp, 0), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
|
|
||||||
adjustedExpiry = expiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
|
||||||
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
@ -2,18 +2,13 @@ package sync
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||||
@ -45,45 +40,17 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
desc: "should skip sync when identity is not a user",
|
desc: "should skip sync when identity is not a user",
|
||||||
identity: &authn.Identity{ID: "service-account:1"},
|
identity: &authn.Identity{ID: "service-account:1"},
|
||||||
|
expectTryRefreshTokenCalled: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||||
identity: &authn.Identity{ID: "user:1"},
|
identity: &authn.Identity{ID: "user:1"},
|
||||||
|
expectTryRefreshTokenCalled: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
desc: "should invalidate access token and session token if token refresh fails",
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync for when access token don't have expire time",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when access token has no expired yet",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when access token has no expired yet",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should refresh access token when it has expired",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectTryRefreshTokenCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
expectHasEntryCalled: true,
|
expectHasEntryCalled: true,
|
||||||
expectedTryRefreshErr: errors.New("some err"),
|
expectedTryRefreshErr: errors.New("some err"),
|
||||||
@ -92,21 +59,27 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
expectRevokeTokenCalled: true,
|
expectRevokeTokenCalled: true,
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
expectedErr: authn.ErrExpiredAccessToken,
|
expectedErr: authn.ErrExpiredAccessToken,
|
||||||
}, {
|
|
||||||
desc: "should skip sync when use_refresh_token is disabled",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectTryRefreshTokenCalled: false,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
||||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should refresh access token when ID token has expired",
|
desc: "should refresh the token successfully",
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
expectHasEntryCalled: true,
|
expectHasEntryCalled: false,
|
||||||
expectTryRefreshTokenCalled: true,
|
expectTryRefreshTokenCalled: true,
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
expectInvalidateOauthTokensCalled: false,
|
||||||
|
expectRevokeTokenCalled: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectInvalidateOauthTokensCalled: false,
|
||||||
|
expectRevokeTokenCalled: false,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
expectedTryRefreshErr: errors.New("some err"),
|
||||||
|
},
|
||||||
|
|
||||||
|
// TODO: address coverage of oauthtoken sync
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -127,7 +100,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
invalidateTokensCalled = true
|
invalidateTokensCalled = true
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) error {
|
||||||
tryRefreshCalled = true
|
tryRefreshCalled = true
|
||||||
return tt.expectedTryRefreshErr
|
return tt.expectedTryRefreshErr
|
||||||
},
|
},
|
||||||
@ -151,12 +124,11 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sync := &OAuthTokenSync{
|
sync := &OAuthTokenSync{
|
||||||
log: log.NewNopLogger(),
|
log: log.NewNopLogger(),
|
||||||
cache: localcache.New(0, 0),
|
service: service,
|
||||||
service: service,
|
sessionService: sessionService,
|
||||||
sessionService: sessionService,
|
socialService: socialService,
|
||||||
socialService: socialService,
|
singleflightGroup: new(singleflight.Group),
|
||||||
sf: new(singleflight.Group),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
||||||
@ -168,93 +140,3 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
|
||||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
|
||||||
type Header struct {
|
|
||||||
Kid string `json:"kid"`
|
|
||||||
Alg string `json:"alg"`
|
|
||||||
}
|
|
||||||
type Payload struct {
|
|
||||||
Iss string `json:"iss"`
|
|
||||||
Sub string `json:"sub"`
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
u := expiryDate.UTC().Unix()
|
|
||||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
fakeSignature := []byte("6ICJm")
|
|
||||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
|
||||||
defaultTime := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accessTokenExpiry time.Time
|
|
||||||
idTokenExpiry time.Time
|
|
||||||
want time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
|
||||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
|
||||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -7,10 +7,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
@ -29,28 +31,33 @@ var (
|
|||||||
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
|
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Cfg *setting.Cfg
|
Cfg *setting.Cfg
|
||||||
SocialService social.Service
|
SocialService social.Service
|
||||||
AuthInfoService login.AuthInfoService
|
AuthInfoService login.AuthInfoService
|
||||||
singleFlightGroup *singleflight.Group
|
singleFlightGroup *singleflight.Group
|
||||||
|
cache *localcache.CacheService
|
||||||
|
|
||||||
tokenRefreshDuration *prometheus.HistogramVec
|
tokenRefreshDuration *prometheus.HistogramVec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:generate mockery --name OAuthTokenService --structname MockService --outpkg oauthtokentest --filename service_mock.go --output ./oauthtokentest/
|
||||||
type OAuthTokenService interface {
|
type OAuthTokenService interface {
|
||||||
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
|
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
|
||||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||||
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
||||||
TryTokenRefresh(context.Context, *login.UserAuth) error
|
TryTokenRefresh(context.Context, identity.Requester) error
|
||||||
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
|
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
|
AuthInfoService: authInfoService,
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
SocialService: socialService,
|
SocialService: socialService,
|
||||||
AuthInfoService: authInfoService,
|
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||||
singleFlightGroup: new(singleflight.Group),
|
singleFlightGroup: new(singleflight.Group),
|
||||||
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
|
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
|
||||||
}
|
}
|
||||||
@ -58,36 +65,12 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
|
|||||||
|
|
||||||
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
|
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
|
||||||
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||||
if usr == nil || usr.IsNil() {
|
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
|
||||||
// No user, therefore no token
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace, id := usr.GetNamespacedID()
|
token, err := o.tryGetOrRefreshOAuthToken(ctx, authInfo)
|
||||||
if namespace != identity.NamespaceUser {
|
|
||||||
// Not a user, therefore no token.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := identity.IntIdentifier(namespace, id)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
|
||||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, user.ErrUserNotFound) {
|
|
||||||
// Not necessarily an error. User may be logged in another way.
|
|
||||||
logger.Debug("No oauth token for user found", "userId", userID, "username", usr.GetLogin())
|
|
||||||
} else {
|
|
||||||
logger.Error("Failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||||
return buildOAuthTokenFromAuthInfo(authInfo)
|
return buildOAuthTokenFromAuthInfo(authInfo)
|
||||||
@ -119,6 +102,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
|||||||
|
|
||||||
userID, err := identity.IntIdentifier(namespace, id)
|
userID, err := identity.IntIdentifier(namespace, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,6 +111,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, user.ErrUserNotFound) {
|
if errors.Is(err, user.ErrUserNotFound) {
|
||||||
// Not necessarily an error. User may be logged in another way.
|
// Not necessarily an error. User may be logged in another way.
|
||||||
|
logger.Debug("No oauth token found for user", "userId", userID, "username", usr.GetLogin())
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
logger.Error("Failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
logger.Error("Failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
||||||
@ -140,13 +125,72 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
|||||||
|
|
||||||
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||||
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
|
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
|
||||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
|
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
|
||||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
|
if usr == nil || usr.IsNil() {
|
||||||
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (any, error) {
|
logger.Warn("Can only refresh OAuth tokens for existing users", "user", "nil")
|
||||||
|
// Not user, no token.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace, id := usr.GetNamespacedID()
|
||||||
|
if namespace != identity.NamespaceUser {
|
||||||
|
// Not a user, therefore no token.
|
||||||
|
logger.Warn("Can only refresh OAuth tokens for users", "namespace", namespace, "userId", id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := identity.IntIdentifier(namespace, id)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
|
||||||
|
if _, ok := o.cache.Get(lockKey); ok {
|
||||||
|
logger.Debug("Expiration check has been cached, no need to refresh", "userID", userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err, _ = o.singleFlightGroup.Do(lockKey, func() (any, error) {
|
||||||
logger.Debug("Singleflight request for getting a new access token", "key", lockKey)
|
logger.Debug("Singleflight request for getting a new access token", "key", lockKey)
|
||||||
|
|
||||||
return o.tryGetOrRefreshAccessToken(ctx, usr)
|
authInfo, exists, err := o.HasOAuthEntry(ctx, usr)
|
||||||
|
if !exists {
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Failed to fetch oauth entry", "id", userID, "error", err)
|
||||||
|
} else {
|
||||||
|
// User is not logged in via OAuth no need to check
|
||||||
|
o.cache.Set(lockKey, struct{}{}, maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, needRefresh, ttl := needTokenRefresh(authInfo)
|
||||||
|
if !needRefresh {
|
||||||
|
o.cache.Set(lockKey, struct{}{}, ttl)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the token's auth provider (f.e. azuread)
|
||||||
|
provider := strings.TrimPrefix(authInfo.AuthModule, "oauth_")
|
||||||
|
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
|
||||||
|
if currentOAuthInfo == nil {
|
||||||
|
logger.Warn("OAuth provider not found", "provider", provider)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if refresh token handling is disabled for this provider, we can skip the refresh
|
||||||
|
if !currentOAuthInfo.UseRefreshToken {
|
||||||
|
logger.Debug("Skipping token refresh", "provider", provider)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.tryGetOrRefreshOAuthToken(ctx, authInfo)
|
||||||
})
|
})
|
||||||
|
// Silence ErrNoRefreshTokenFound
|
||||||
|
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,11 +239,23 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
|
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
|
||||||
|
key := getCheckCacheKey(usr.UserId)
|
||||||
|
if _, ok := o.cache.Get(key); ok {
|
||||||
|
logger.Debug("Expiration check has been cached", "userID", usr.UserId)
|
||||||
|
return buildOAuthTokenFromAuthInfo(usr), nil
|
||||||
|
}
|
||||||
|
|
||||||
if err := checkOAuthRefreshToken(usr); err != nil {
|
if err := checkOAuthRefreshToken(usr); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
persistedToken, refreshNeeded, ttl := needTokenRefresh(usr)
|
||||||
|
if !refreshNeeded {
|
||||||
|
o.cache.Set(key, struct{}{}, ttl)
|
||||||
|
return persistedToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
authProvider := usr.AuthModule
|
authProvider := usr.AuthModule
|
||||||
connect, err := o.SocialService.GetConnector(authProvider)
|
connect, err := o.SocialService.GetConnector(authProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -214,8 +270,6 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use
|
|||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||||
|
|
||||||
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
// TokenSource handles refreshing the token if it has expired
|
// TokenSource handles refreshing the token if it has expired
|
||||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||||
@ -278,8 +332,91 @@ func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus
|
|||||||
|
|
||||||
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
||||||
func tokensEq(t1, t2 *oauth2.Token) bool {
|
func tokensEq(t1, t2 *oauth2.Token) bool {
|
||||||
|
t1IdToken, ok1 := t1.Extra("id_token").(string)
|
||||||
|
t2IdToken, ok2 := t2.Extra("id_token").(string)
|
||||||
|
|
||||||
return t1.AccessToken == t2.AccessToken &&
|
return t1.AccessToken == t2.AccessToken &&
|
||||||
t1.RefreshToken == t2.RefreshToken &&
|
t1.RefreshToken == t2.RefreshToken &&
|
||||||
t1.Expiry.Equal(t2.Expiry) &&
|
t1.Expiry.Equal(t2.Expiry) &&
|
||||||
t1.TokenType == t2.TokenType
|
t1.TokenType == t2.TokenType &&
|
||||||
|
ok1 == ok2 &&
|
||||||
|
t1IdToken == t2IdToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func needTokenRefresh(usr *login.UserAuth) (*oauth2.Token, bool, time.Duration) {
|
||||||
|
var accessTokenExpires, idTokenExpires time.Time
|
||||||
|
var hasAccessTokenExpired, hasIdTokenExpired bool
|
||||||
|
|
||||||
|
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
||||||
|
idTokenExp, err := getIDTokenExpiry(usr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Could not get ID Token expiry", "error", err)
|
||||||
|
}
|
||||||
|
if !persistedToken.Expiry.IsZero() {
|
||||||
|
accessTokenExpires, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
|
||||||
|
}
|
||||||
|
if !idTokenExp.IsZero() {
|
||||||
|
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
|
||||||
|
}
|
||||||
|
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
||||||
|
logger.Debug("Neither access nor id token have expired yet", "id", usr.Id)
|
||||||
|
return persistedToken, false, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)
|
||||||
|
}
|
||||||
|
if hasIdTokenExpired {
|
||||||
|
// Force refreshing token when id token is expired
|
||||||
|
persistedToken.AccessToken = ""
|
||||||
|
}
|
||||||
|
return persistedToken, true, time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCheckCacheKey(usrID int64) string {
|
||||||
|
return fmt.Sprintf("token-check-%d", usrID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||||
|
min := maxOAuthTokenCacheTTL
|
||||||
|
if !accessTokenExpiry.IsZero() {
|
||||||
|
d := time.Until(accessTokenExpiry)
|
||||||
|
if d < min {
|
||||||
|
min = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !idTokenExpiry.IsZero() {
|
||||||
|
d := time.Until(idTokenExpiry)
|
||||||
|
if d < min {
|
||||||
|
min = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||||
|
return maxOAuthTokenCacheTTL
|
||||||
|
}
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
|
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||||
|
func getIDTokenExpiry(usr *login.UserAuth) (time.Time, error) {
|
||||||
|
if usr.OAuthIdToken == "" {
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedToken, err := jwt.ParseSigned(usr.OAuthIdToken)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
var claims Claims
|
||||||
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Unix(claims.Exp, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
|
||||||
|
adjustedExpiry = expiry.Round(0).Add(-ExpiryDelta)
|
||||||
|
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
@ -10,20 +10,26 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||||
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
|
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/authinfotest"
|
||||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||||
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
|
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
|
||||||
|
|
||||||
func TestService_HasOAuthEntry(t *testing.T) {
|
func TestService_HasOAuthEntry(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -69,10 +75,10 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "returns true when the auth entry is found",
|
name: "returns true when the auth entry is found",
|
||||||
user: &user.SignedInUser{UserID: 1},
|
user: &user.SignedInUser{UserID: 1},
|
||||||
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
|
want: &login.UserAuth{AuthModule: login.GenericOAuthModule},
|
||||||
wantExist: true,
|
wantExist: true,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
getAuthInfoUser: login.UserAuth{AuthModule: "oauth_generic_oauth"},
|
getAuthInfoUser: login.UserAuth{AuthModule: login.GenericOAuthModule},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@ -96,152 +102,26 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
|
|
||||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
|
||||||
ctx := context.Background()
|
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "testrefresh",
|
|
||||||
Expiry: time.Now(),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
usr := &login.UserAuth{
|
|
||||||
AuthModule: "oauth_generic_oauth",
|
|
||||||
OAuthAccessToken: token.AccessToken,
|
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
|
||||||
OAuthExpiry: token.Expiry,
|
|
||||||
OAuthTokenType: token.TokenType,
|
|
||||||
}
|
|
||||||
|
|
||||||
authInfoStore.ExpectedOAuth = usr
|
|
||||||
|
|
||||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
|
||||||
|
|
||||||
err := srv.TryTokenRefresh(ctx, usr)
|
|
||||||
require.Nil(t, err)
|
|
||||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
|
||||||
|
|
||||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
|
|
||||||
resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// User's token data had not been updated
|
|
||||||
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
|
|
||||||
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
|
|
||||||
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
|
|
||||||
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
|
|
||||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
|
||||||
ctx := context.Background()
|
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "",
|
|
||||||
Expiry: time.Now().Add(-time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
usr := &login.UserAuth{
|
|
||||||
AuthModule: "oauth_generic_oauth",
|
|
||||||
OAuthAccessToken: token.AccessToken,
|
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
|
||||||
OAuthExpiry: token.Expiry,
|
|
||||||
OAuthTokenType: token.TokenType,
|
|
||||||
}
|
|
||||||
|
|
||||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
|
||||||
|
|
||||||
err := srv.TryTokenRefresh(ctx, usr)
|
|
||||||
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
|
|
||||||
|
|
||||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
|
|
||||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
|
||||||
ctx := context.Background()
|
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "testrefresh",
|
|
||||||
Expiry: time.Now().Add(-time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
|
|
||||||
newToken := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess_new",
|
|
||||||
RefreshToken: "testrefresh_new",
|
|
||||||
Expiry: time.Now().Add(time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
|
|
||||||
usr := &login.UserAuth{
|
|
||||||
AuthModule: "oauth_generic_oauth",
|
|
||||||
UserId: 1,
|
|
||||||
AuthId: "test",
|
|
||||||
OAuthAccessToken: token.AccessToken,
|
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
|
||||||
OAuthExpiry: token.Expiry,
|
|
||||||
OAuthTokenType: token.TokenType,
|
|
||||||
}
|
|
||||||
|
|
||||||
authInfoStore.ExpectedOAuth = usr
|
|
||||||
|
|
||||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
|
|
||||||
|
|
||||||
err := srv.TryTokenRefresh(ctx, usr)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
|
||||||
|
|
||||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
|
|
||||||
authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
|
||||||
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
|
||||||
assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken)
|
|
||||||
assert.Equal(t, authInfo.OAuthExpiry, newToken.Expiry)
|
|
||||||
assert.Equal(t, authInfo.OAuthRefreshToken, newToken.RefreshToken)
|
|
||||||
assert.Equal(t, authInfo.OAuthTokenType, newToken.TokenType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
|
|
||||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
|
||||||
ctx := context.Background()
|
|
||||||
token := &oauth2.Token{}
|
|
||||||
usr := &login.UserAuth{
|
|
||||||
AuthModule: "auth.saml",
|
|
||||||
}
|
|
||||||
|
|
||||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
|
||||||
|
|
||||||
err := srv.TryTokenRefresh(ctx, usr)
|
|
||||||
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
|
|
||||||
|
|
||||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
|
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
socialConnector := &socialtest.MockSocialConnector{}
|
socialConnector := &socialtest.MockSocialConnector{}
|
||||||
socialService := &socialtest.FakeSocialService{
|
socialService := &socialtest.FakeSocialService{
|
||||||
ExpectedConnector: socialConnector,
|
ExpectedConnector: socialConnector,
|
||||||
|
ExpectedAuthInfoProvider: &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfoStore := &FakeAuthInfoStore{}
|
authInfoStore := &FakeAuthInfoStore{ExpectedOAuth: &login.UserAuth{}}
|
||||||
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(),
|
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(), secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
|
||||||
secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
|
|
||||||
return &Service{
|
return &Service{
|
||||||
Cfg: setting.NewCfg(),
|
Cfg: setting.NewCfg(),
|
||||||
SocialService: socialService,
|
SocialService: socialService,
|
||||||
AuthInfoService: authInfoService,
|
AuthInfoService: authInfoService,
|
||||||
singleFlightGroup: &singleflight.Group{},
|
singleFlightGroup: &singleflight.Group{},
|
||||||
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||||
|
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||||
}, authInfoStore, socialConnector
|
}, authInfoStore, socialConnector
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,3 +150,461 @@ func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.Updat
|
|||||||
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
||||||
return f.ExpectedError
|
return f.ExpectedError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestService_TryTokenRefresh(t *testing.T) {
|
||||||
|
type environment struct {
|
||||||
|
authInfoService *authinfotest.FakeService
|
||||||
|
cache *localcache.CacheService
|
||||||
|
identity identity.Requester
|
||||||
|
socialConnector *socialtest.MockSocialConnector
|
||||||
|
socialService *socialtest.FakeSocialService
|
||||||
|
|
||||||
|
service *Service
|
||||||
|
}
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
expectedErr error
|
||||||
|
setup func(env *environment)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should skip sync when identity is nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when identity is not a user",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "service-account:1"}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:invalidIdentifierFormat"}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh since the token is still valid",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.identity = &authn.Identity{
|
||||||
|
AuthenticatedBy: login.GenericOAuthModule,
|
||||||
|
ID: "user:1234",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh if the expiration check has already been cached",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedError = errors.New("some error")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh if the user doesn't have an oauth entry",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.SAMLAuthModule,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should do token refresh if access token or id token have not expired yet",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh when no oauth provider was found",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthIdToken: EXPIRED_JWT,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthIdToken: EXPIRED_JWT,
|
||||||
|
}
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: false,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh when there is no refresh token",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthIdToken: EXPIRED_JWT,
|
||||||
|
OAuthRefreshToken: "",
|
||||||
|
}
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should do token refresh when the token is expired",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
AuthId: "subject",
|
||||||
|
UserId: 1,
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
}
|
||||||
|
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should refresh token when the id token is expired",
|
||||||
|
setup: func(env *environment) {
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
env.identity = &authn.Identity{ID: "user:1234"}
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
AuthId: "subject",
|
||||||
|
UserId: 1,
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
OAuthIdToken: EXPIRED_JWT,
|
||||||
|
}
|
||||||
|
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
socialConnector := &socialtest.MockSocialConnector{}
|
||||||
|
|
||||||
|
env := environment{
|
||||||
|
authInfoService: &authinfotest.FakeService{},
|
||||||
|
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||||
|
socialConnector: socialConnector,
|
||||||
|
socialService: &socialtest.FakeSocialService{
|
||||||
|
ExpectedConnector: socialConnector,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.setup != nil {
|
||||||
|
tt.setup(&env)
|
||||||
|
}
|
||||||
|
|
||||||
|
env.service = &Service{
|
||||||
|
AuthInfoService: env.authInfoService,
|
||||||
|
Cfg: setting.NewCfg(),
|
||||||
|
cache: env.cache,
|
||||||
|
singleFlightGroup: &singleflight.Group{},
|
||||||
|
SocialService: env.socialService,
|
||||||
|
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||||
|
}
|
||||||
|
|
||||||
|
// token refresh
|
||||||
|
err := env.service.TryTokenRefresh(context.Background(), env.identity)
|
||||||
|
|
||||||
|
// test and validations
|
||||||
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
socialConnector.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||||
|
defaultTime := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessTokenExpiry time.Time
|
||||||
|
idTokenExpiry time.Time
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
usr *login.UserAuth
|
||||||
|
expectedTokenRefreshFlag bool
|
||||||
|
expectedTokenDuration time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should not need token refresh when token has no expiration date",
|
||||||
|
usr: &login.UserAuth{},
|
||||||
|
expectedTokenRefreshFlag: false,
|
||||||
|
expectedTokenDuration: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not need token refresh with an invalid jwt token that might result in an error when parsing",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
OAuthIdToken: "invalid_jwt_format",
|
||||||
|
},
|
||||||
|
expectedTokenRefreshFlag: false,
|
||||||
|
expectedTokenDuration: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should flag token refresh with id token is expired",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
OAuthIdToken: EXPIRED_JWT,
|
||||||
|
},
|
||||||
|
expectedTokenRefreshFlag: true,
|
||||||
|
expectedTokenDuration: time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should flag token refresh when expiry date is zero",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
OAuthExpiry: time.Unix(0, 0),
|
||||||
|
},
|
||||||
|
expectedTokenRefreshFlag: true,
|
||||||
|
expectedTokenDuration: time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token, needsTokenRefresh, tokenDuration := needTokenRefresh(tt.usr)
|
||||||
|
|
||||||
|
assert.NotNil(t, token)
|
||||||
|
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
|
||||||
|
assert.Equal(t, tt.expectedTokenDuration, tokenDuration)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
|
||||||
|
timeNow := time.Now()
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "oauth_access_token",
|
||||||
|
RefreshToken: "refresh_token_found",
|
||||||
|
Expiry: timeNow,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
type environment struct {
|
||||||
|
authInfoService *authinfotest.FakeService
|
||||||
|
cache *localcache.CacheService
|
||||||
|
socialConnector *socialtest.MockSocialConnector
|
||||||
|
socialService *socialtest.FakeSocialService
|
||||||
|
|
||||||
|
service *Service
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
expectedErr error
|
||||||
|
expectedToken *oauth2.Token
|
||||||
|
usr *login.UserAuth
|
||||||
|
setup func(env *environment)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "should find and retrieve token from cache",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
UserId: int64(1234),
|
||||||
|
OAuthAccessToken: "new_access_token",
|
||||||
|
OAuthExpiry: timeNow,
|
||||||
|
},
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.cache.Set("token-check-1234", token, 1*time.Minute)
|
||||||
|
},
|
||||||
|
expectedToken: &oauth2.Token{
|
||||||
|
AccessToken: "new_access_token",
|
||||||
|
Expiry: timeNow,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
UserId: int64(1234),
|
||||||
|
AuthModule: login.SAMLAuthModule,
|
||||||
|
},
|
||||||
|
expectedErr: ErrNotAnOAuthProvider,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
UserId: int64(1234),
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
},
|
||||||
|
expectedErr: ErrNoRefreshTokenFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should not refresh token if the token is not expired",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
UserId: int64(1234),
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: timeNow.Add(time.Hour),
|
||||||
|
OAuthTokenType: "Bearer",
|
||||||
|
},
|
||||||
|
expectedToken: &oauth2.Token{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
Expiry: timeNow.Add(time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should update saved token if the user auth has new access/refresh tokens",
|
||||||
|
usr: &login.UserAuth{
|
||||||
|
UserId: int64(1234),
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthAccessToken: "new_oauth_access_token",
|
||||||
|
OAuthRefreshToken: "new_refresh_token_found",
|
||||||
|
OAuthExpiry: timeNow,
|
||||||
|
},
|
||||||
|
expectedToken: &oauth2.Token{
|
||||||
|
AccessToken: "oauth_access_token",
|
||||||
|
RefreshToken: "refresh_token_found",
|
||||||
|
Expiry: timeNow,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
},
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
socialConnector := &socialtest.MockSocialConnector{}
|
||||||
|
|
||||||
|
env := environment{
|
||||||
|
authInfoService: &authinfotest.FakeService{},
|
||||||
|
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||||
|
socialConnector: socialConnector,
|
||||||
|
socialService: &socialtest.FakeSocialService{
|
||||||
|
ExpectedConnector: socialConnector,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.setup != nil {
|
||||||
|
tt.setup(&env)
|
||||||
|
}
|
||||||
|
|
||||||
|
env.service = &Service{
|
||||||
|
AuthInfoService: env.authInfoService,
|
||||||
|
Cfg: setting.NewCfg(),
|
||||||
|
cache: env.cache,
|
||||||
|
singleFlightGroup: &singleflight.Group{},
|
||||||
|
SocialService: env.socialService,
|
||||||
|
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)
|
||||||
|
|
||||||
|
if tt.expectedToken != nil {
|
||||||
|
assert.Equal(t, tt.expectedToken, token)
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, tt.expectedErr, err)
|
||||||
|
socialConnector.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -15,7 +15,7 @@ type MockOauthTokenService struct {
|
|||||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||||
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
||||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
|
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
|
||||||
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
|
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||||
@ -46,7 +46,7 @@ func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
|
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
|
||||||
if m.TryTokenRefreshFunc != nil {
|
if m.TryTokenRefreshFunc != nil {
|
||||||
return m.TryTokenRefreshFunc(ctx, usr)
|
return m.TryTokenRefreshFunc(ctx, usr)
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.Use
|
|||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) TryTokenRefresh(context.Context, *login.UserAuth) error {
|
func (s *Service) TryTokenRefresh(context.Context, identity.Requester) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
146
pkg/services/oauthtoken/oauthtokentest/service_mock.go
Normal file
146
pkg/services/oauthtoken/oauthtokentest/service_mock.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
// Code generated by mockery v2.40.1. DO NOT EDIT.
|
||||||
|
|
||||||
|
package oauthtokentest
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
|
||||||
|
identity "github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
|
datasources "github.com/grafana/grafana/pkg/services/datasources"
|
||||||
|
|
||||||
|
login "github.com/grafana/grafana/pkg/services/login"
|
||||||
|
|
||||||
|
mock "github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
oauth2 "golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockService is an autogenerated mock type for the OAuthTokenService type
|
||||||
|
type MockService struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentOAuthToken provides a mock function with given fields: _a0, _a1
|
||||||
|
func (_m *MockService) GetCurrentOAuthToken(_a0 context.Context, _a1 identity.Requester) *oauth2.Token {
|
||||||
|
ret := _m.Called(_a0, _a1)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for GetCurrentOAuthToken")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 *oauth2.Token
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *oauth2.Token); ok {
|
||||||
|
r0 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(*oauth2.Token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasOAuthEntry provides a mock function with given fields: _a0, _a1
|
||||||
|
func (_m *MockService) HasOAuthEntry(_a0 context.Context, _a1 identity.Requester) (*login.UserAuth, bool, error) {
|
||||||
|
ret := _m.Called(_a0, _a1)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for HasOAuthEntry")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 *login.UserAuth
|
||||||
|
var r1 bool
|
||||||
|
var r2 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) (*login.UserAuth, bool, error)); ok {
|
||||||
|
return rf(_a0, _a1)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *login.UserAuth); ok {
|
||||||
|
r0 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(*login.UserAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, identity.Requester) bool); ok {
|
||||||
|
r1 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Get(1).(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(2).(func(context.Context, identity.Requester) error); ok {
|
||||||
|
r2 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
r2 = ret.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1, r2
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateOAuthTokens provides a mock function with given fields: _a0, _a1
|
||||||
|
func (_m *MockService) InvalidateOAuthTokens(_a0 context.Context, _a1 *login.UserAuth) error {
|
||||||
|
ret := _m.Called(_a0, _a1)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for InvalidateOAuthTokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, *login.UserAuth) error); ok {
|
||||||
|
r0 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOAuthPassThruEnabled provides a mock function with given fields: _a0
|
||||||
|
func (_m *MockService) IsOAuthPassThruEnabled(_a0 *datasources.DataSource) bool {
|
||||||
|
ret := _m.Called(_a0)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for IsOAuthPassThruEnabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 bool
|
||||||
|
if rf, ok := ret.Get(0).(func(*datasources.DataSource) bool); ok {
|
||||||
|
r0 = rf(_a0)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryTokenRefresh provides a mock function with given fields: _a0, _a1
|
||||||
|
func (_m *MockService) TryTokenRefresh(_a0 context.Context, _a1 identity.Requester) error {
|
||||||
|
ret := _m.Called(_a0, _a1)
|
||||||
|
|
||||||
|
if len(ret) == 0 {
|
||||||
|
panic("no return value specified for TryTokenRefresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) error); ok {
|
||||||
|
r0 = rf(_a0, _a1)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
// The first argument is typically a *testing.T value.
|
||||||
|
func NewMockService(t interface {
|
||||||
|
mock.TestingT
|
||||||
|
Cleanup(func())
|
||||||
|
}) *MockService {
|
||||||
|
mock := &MockService{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user