mirror of
https://github.com/grafana/grafana.git
synced 2025-01-09 23:53:25 -06:00
Fix: Refresh token when id_token is expired (#79569)
* Fix: Refresh token when id_token is expired * add id_token comparison * Fix wire * Use userID as cache key * Apply suggestions from code review --------- Co-authored-by: linoman <2051016+linoman@users.noreply.github.com> Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
parent
62806e8f8c
commit
596e828150
@ -3,26 +3,20 @@ package sync
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
)
|
||||
|
||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||
return &OAuthTokenSync{
|
||||
log.New("oauth_token.sync"),
|
||||
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
service,
|
||||
sessionService,
|
||||
socialService,
|
||||
@ -31,12 +25,11 @@ func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService
|
||||
}
|
||||
|
||||
type OAuthTokenSync struct {
|
||||
log log.Logger
|
||||
cache *localcache.CacheService
|
||||
service oauthtoken.OAuthTokenService
|
||||
sessionService auth.UserTokenService
|
||||
socialService social.Service
|
||||
sf *singleflight.Group
|
||||
log log.Logger
|
||||
service oauthtoken.OAuthTokenService
|
||||
sessionService auth.UserTokenService
|
||||
socialService social.Service
|
||||
singleflightGroup *singleflight.Group
|
||||
}
|
||||
|
||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||
@ -51,71 +44,14 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
// if we recently have performed this it would be cached, so we can skip the hook
|
||||
if _, ok := s.cache.Get(identity.ID); ok {
|
||||
s.log.FromContext(ctx).Debug("OAuth token check is cached", "id", identity.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
token, exists, err := s.service.HasOAuthEntry(ctx, identity)
|
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists {
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to fetch oauth entry", "id", identity.ID, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() {
|
||||
s.log.FromContext(ctx).Debug("Access token without expiry", "id", identity.ID)
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
||||
if currentOAuthInfo == nil {
|
||||
s.log.Warn("OAuth provider not found", "provider", provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
||||
if !currentOAuthInfo.UseRefreshToken {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessTokenExpires, hasAccessTokenExpired := getExpiryWithSkew(token.OAuthExpiry)
|
||||
|
||||
hasIdTokenExpired := false
|
||||
idTokenExpires := time.Time{}
|
||||
|
||||
if !idTokenExpiry.IsZero() {
|
||||
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExpiry)
|
||||
}
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
||||
s.log.FromContext(ctx).Debug("Access and id token has not expired yet", "id", identity.ID)
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err, _ = s.sf.Do(identity.ID, func() (interface{}, error) {
|
||||
_, err, _ := s.singleflightGroup.Do(identity.ID, func() (interface{}, error) {
|
||||
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
|
||||
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if refreshErr := s.service.TryTokenRefresh(updateCtx, token); refreshErr != nil {
|
||||
if refreshErr := s.service.TryTokenRefresh(updateCtx, identity); refreshErr != nil {
|
||||
if errors.Is(refreshErr, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
@ -153,56 +89,3 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return maxOAuthTokenCacheTTL
|
||||
}
|
||||
|
||||
min := func(a, b time.Duration) time.Duration {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||
if token.OAuthIdToken == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
var claims Claims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
|
||||
adjustedExpiry = expiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
||||
return
|
||||
}
|
||||
|
@ -2,18 +2,13 @@ package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||
@ -45,45 +40,17 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should skip sync when identity is not a user",
|
||||
identity: &authn.Identity{ID: "service-account:1"},
|
||||
desc: "should skip sync when identity is not a user",
|
||||
identity: &authn.Identity{ID: "service-account:1"},
|
||||
expectTryRefreshTokenCalled: false,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
expectTryRefreshTokenCalled: false,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync for when access token don't have expire time",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when it has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||
desc: "should invalidate access token and session token if token refresh fails",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedTryRefreshErr: errors.New("some err"),
|
||||
@ -92,21 +59,27 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
expectedErr: authn.ErrExpiredAccessToken,
|
||||
}, {
|
||||
desc: "should skip sync when use_refresh_token is disabled",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: false,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when ID token has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||
desc: "should refresh the token successfully",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: false,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectInvalidateOauthTokensCalled: false,
|
||||
expectRevokeTokenCalled: false,
|
||||
},
|
||||
{
|
||||
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectInvalidateOauthTokensCalled: false,
|
||||
expectRevokeTokenCalled: false,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
expectedTryRefreshErr: errors.New("some err"),
|
||||
},
|
||||
|
||||
// TODO: address coverage of oauthtoken sync
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -127,7 +100,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
invalidateTokensCalled = true
|
||||
return nil
|
||||
},
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) error {
|
||||
tryRefreshCalled = true
|
||||
return tt.expectedTryRefreshErr
|
||||
},
|
||||
@ -151,12 +124,11 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
}
|
||||
|
||||
sync := &OAuthTokenSync{
|
||||
log: log.NewNopLogger(),
|
||||
cache: localcache.New(0, 0),
|
||||
service: service,
|
||||
sessionService: sessionService,
|
||||
socialService: socialService,
|
||||
sf: new(singleflight.Group),
|
||||
log: log.NewNopLogger(),
|
||||
service: service,
|
||||
sessionService: sessionService,
|
||||
socialService: socialService,
|
||||
singleflightGroup: new(singleflight.Group),
|
||||
}
|
||||
|
||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
||||
@ -168,93 +140,3 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||
type Header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
type Payload struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||
require.NoError(t, err)
|
||||
u := expiryDate.UTC().Unix()
|
||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeSignature := []byte("6ICJm")
|
||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||
defaultTime := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenExpiry time.Time
|
||||
idTokenExpiry time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: time.Time{},
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -7,10 +7,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
@ -29,28 +31,33 @@ var (
|
||||
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
|
||||
)
|
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||
|
||||
type Service struct {
|
||||
Cfg *setting.Cfg
|
||||
SocialService social.Service
|
||||
AuthInfoService login.AuthInfoService
|
||||
singleFlightGroup *singleflight.Group
|
||||
cache *localcache.CacheService
|
||||
|
||||
tokenRefreshDuration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
//go:generate mockery --name OAuthTokenService --structname MockService --outpkg oauthtokentest --filename service_mock.go --output ./oauthtokentest/
|
||||
type OAuthTokenService interface {
|
||||
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
|
||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
||||
TryTokenRefresh(context.Context, *login.UserAuth) error
|
||||
TryTokenRefresh(context.Context, identity.Requester) error
|
||||
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
||||
}
|
||||
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
|
||||
return &Service{
|
||||
AuthInfoService: authInfoService,
|
||||
Cfg: cfg,
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
singleFlightGroup: new(singleflight.Group),
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
|
||||
}
|
||||
@ -58,36 +65,12 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
|
||||
|
||||
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
|
||||
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||
if usr == nil || usr.IsNil() {
|
||||
// No user, therefore no token
|
||||
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
namespace, id := usr.GetNamespacedID()
|
||||
if namespace != identity.NamespaceUser {
|
||||
// Not a user, therefore no token.
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("No oauth token for user found", "userId", userID, "username", usr.GetLogin())
|
||||
} else {
|
||||
logger.Error("Failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfo)
|
||||
token, err := o.tryGetOrRefreshOAuthToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return buildOAuthTokenFromAuthInfo(authInfo)
|
||||
@ -119,6 +102,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
logger.Error("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
@ -127,6 +111,7 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("No oauth token found for user", "userId", userID, "username", usr.GetLogin())
|
||||
return nil, false, nil
|
||||
}
|
||||
logger.Error("Failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
||||
@ -140,13 +125,72 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
||||
|
||||
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
|
||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
|
||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
|
||||
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (any, error) {
|
||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
|
||||
if usr == nil || usr.IsNil() {
|
||||
logger.Warn("Can only refresh OAuth tokens for existing users", "user", "nil")
|
||||
// Not user, no token.
|
||||
return nil
|
||||
}
|
||||
|
||||
namespace, id := usr.GetNamespacedID()
|
||||
if namespace != identity.NamespaceUser {
|
||||
// Not a user, therefore no token.
|
||||
logger.Warn("Can only refresh OAuth tokens for users", "namespace", namespace, "userId", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
|
||||
if _, ok := o.cache.Get(lockKey); ok {
|
||||
logger.Debug("Expiration check has been cached, no need to refresh", "userID", userID)
|
||||
return nil
|
||||
}
|
||||
_, err, _ = o.singleFlightGroup.Do(lockKey, func() (any, error) {
|
||||
logger.Debug("Singleflight request for getting a new access token", "key", lockKey)
|
||||
|
||||
return o.tryGetOrRefreshAccessToken(ctx, usr)
|
||||
authInfo, exists, err := o.HasOAuthEntry(ctx, usr)
|
||||
if !exists {
|
||||
if err != nil {
|
||||
logger.Debug("Failed to fetch oauth entry", "id", userID, "error", err)
|
||||
} else {
|
||||
// User is not logged in via OAuth no need to check
|
||||
o.cache.Set(lockKey, struct{}{}, maxOAuthTokenCacheTTL)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, needRefresh, ttl := needTokenRefresh(authInfo)
|
||||
if !needRefresh {
|
||||
o.cache.Set(lockKey, struct{}{}, ttl)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
provider := strings.TrimPrefix(authInfo.AuthModule, "oauth_")
|
||||
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
|
||||
if currentOAuthInfo == nil {
|
||||
logger.Warn("OAuth provider not found", "provider", provider)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// if refresh token handling is disabled for this provider, we can skip the refresh
|
||||
if !currentOAuthInfo.UseRefreshToken {
|
||||
logger.Debug("Skipping token refresh", "provider", provider)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return o.tryGetOrRefreshOAuthToken(ctx, authInfo)
|
||||
})
|
||||
// Silence ErrNoRefreshTokenFound
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -195,11 +239,23 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth
|
||||
})
|
||||
}
|
||||
|
||||
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
|
||||
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
|
||||
key := getCheckCacheKey(usr.UserId)
|
||||
if _, ok := o.cache.Get(key); ok {
|
||||
logger.Debug("Expiration check has been cached", "userID", usr.UserId)
|
||||
return buildOAuthTokenFromAuthInfo(usr), nil
|
||||
}
|
||||
|
||||
if err := checkOAuthRefreshToken(usr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistedToken, refreshNeeded, ttl := needTokenRefresh(usr)
|
||||
if !refreshNeeded {
|
||||
o.cache.Set(key, struct{}{}, ttl)
|
||||
return persistedToken, nil
|
||||
}
|
||||
|
||||
authProvider := usr.AuthModule
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
if err != nil {
|
||||
@ -214,8 +270,6 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
|
||||
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
||||
|
||||
start := time.Now()
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
@ -278,8 +332,91 @@ func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus
|
||||
|
||||
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
||||
func tokensEq(t1, t2 *oauth2.Token) bool {
|
||||
t1IdToken, ok1 := t1.Extra("id_token").(string)
|
||||
t2IdToken, ok2 := t2.Extra("id_token").(string)
|
||||
|
||||
return t1.AccessToken == t2.AccessToken &&
|
||||
t1.RefreshToken == t2.RefreshToken &&
|
||||
t1.Expiry.Equal(t2.Expiry) &&
|
||||
t1.TokenType == t2.TokenType
|
||||
t1.TokenType == t2.TokenType &&
|
||||
ok1 == ok2 &&
|
||||
t1IdToken == t2IdToken
|
||||
}
|
||||
|
||||
func needTokenRefresh(usr *login.UserAuth) (*oauth2.Token, bool, time.Duration) {
|
||||
var accessTokenExpires, idTokenExpires time.Time
|
||||
var hasAccessTokenExpired, hasIdTokenExpired bool
|
||||
|
||||
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
||||
idTokenExp, err := getIDTokenExpiry(usr)
|
||||
if err != nil {
|
||||
logger.Warn("Could not get ID Token expiry", "error", err)
|
||||
}
|
||||
if !persistedToken.Expiry.IsZero() {
|
||||
accessTokenExpires, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
|
||||
}
|
||||
if !idTokenExp.IsZero() {
|
||||
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
|
||||
}
|
||||
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
||||
logger.Debug("Neither access nor id token have expired yet", "id", usr.Id)
|
||||
return persistedToken, false, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)
|
||||
}
|
||||
if hasIdTokenExpired {
|
||||
// Force refreshing token when id token is expired
|
||||
persistedToken.AccessToken = ""
|
||||
}
|
||||
return persistedToken, true, time.Second
|
||||
}
|
||||
|
||||
func getCheckCacheKey(usrID int64) string {
|
||||
return fmt.Sprintf("token-check-%d", usrID)
|
||||
}
|
||||
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||
min := maxOAuthTokenCacheTTL
|
||||
if !accessTokenExpiry.IsZero() {
|
||||
d := time.Until(accessTokenExpiry)
|
||||
if d < min {
|
||||
min = d
|
||||
}
|
||||
}
|
||||
if !idTokenExpiry.IsZero() {
|
||||
d := time.Until(idTokenExpiry)
|
||||
if d < min {
|
||||
min = d
|
||||
}
|
||||
}
|
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return maxOAuthTokenCacheTTL
|
||||
}
|
||||
return min
|
||||
}
|
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(usr *login.UserAuth) (time.Time, error) {
|
||||
if usr.OAuthIdToken == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.ParseSigned(usr.OAuthIdToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
var claims Claims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
|
||||
adjustedExpiry = expiry.Round(0).Add(-ExpiryDelta)
|
||||
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
||||
return
|
||||
}
|
||||
|
@ -10,20 +10,26 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfotest"
|
||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
|
||||
|
||||
func TestService_HasOAuthEntry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -69,10 +75,10 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
{
|
||||
name: "returns true when the auth entry is found",
|
||||
user: &user.SignedInUser{UserID: 1},
|
||||
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||
want: &login.UserAuth{AuthModule: login.GenericOAuthModule},
|
||||
wantExist: true,
|
||||
wantErr: false,
|
||||
getAuthInfoUser: login.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||
getAuthInfoUser: login.UserAuth{AuthModule: login.GenericOAuthModule},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
@ -96,152 +102,26 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
|
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
usr := &login.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
authInfoStore.ExpectedOAuth = usr
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
require.Nil(t, err)
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
|
||||
resultUsr, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
require.Nil(t, err)
|
||||
|
||||
// User's token data had not been updated
|
||||
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
|
||||
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
|
||||
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
|
||||
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
|
||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
usr := &login.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
|
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
|
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
newToken := &oauth2.Token{
|
||||
AccessToken: "testaccess_new",
|
||||
RefreshToken: "testrefresh_new",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
usr := &login.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
UserId: 1,
|
||||
AuthId: "test",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
authInfoStore.ExpectedOAuth = usr
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
require.Nil(t, err)
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: 1}
|
||||
authInfo, err := srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
||||
assert.Equal(t, authInfo.OAuthAccessToken, newToken.AccessToken)
|
||||
assert.Equal(t, authInfo.OAuthExpiry, newToken.Expiry)
|
||||
assert.Equal(t, authInfo.OAuthRefreshToken, newToken.RefreshToken)
|
||||
assert.Equal(t, authInfo.OAuthTokenType, newToken.TokenType)
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
|
||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{}
|
||||
usr := &login.UserAuth{
|
||||
AuthModule: "auth.saml",
|
||||
}
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
|
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||
}
|
||||
|
||||
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
|
||||
t.Helper()
|
||||
|
||||
socialConnector := &socialtest.MockSocialConnector{}
|
||||
socialService := &socialtest.FakeSocialService{
|
||||
ExpectedConnector: socialConnector,
|
||||
ExpectedAuthInfoProvider: &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
},
|
||||
}
|
||||
|
||||
authInfoStore := &FakeAuthInfoStore{}
|
||||
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(),
|
||||
secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
|
||||
authInfoStore := &FakeAuthInfoStore{ExpectedOAuth: &login.UserAuth{}}
|
||||
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(), secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
|
||||
return &Service{
|
||||
Cfg: setting.NewCfg(),
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
}, authInfoStore, socialConnector
|
||||
}
|
||||
|
||||
@ -270,3 +150,461 @@ func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.Updat
|
||||
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh(t *testing.T) {
|
||||
type environment struct {
|
||||
authInfoService *authinfotest.FakeService
|
||||
cache *localcache.CacheService
|
||||
identity identity.Requester
|
||||
socialConnector *socialtest.MockSocialConnector
|
||||
socialService *socialtest.FakeSocialService
|
||||
|
||||
service *Service
|
||||
}
|
||||
type testCase struct {
|
||||
desc string
|
||||
expectedErr error
|
||||
setup func(env *environment)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should skip sync when identity is nil",
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when identity is not a user",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "service-account:1"}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:invalidIdentifierFormat"}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh since the token is still valid",
|
||||
setup: func(env *environment) {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
env.identity = &authn.Identity{
|
||||
AuthenticatedBy: login.GenericOAuthModule,
|
||||
ID: "user:1234",
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh if the expiration check has already been cached",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedError = errors.New("some error")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh if the user doesn't have an oauth entry",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.SAMLAuthModule,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should do token refresh if access token or id token have not expired yet",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh when no oauth provider was found",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthIdToken: EXPIRED_JWT,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthIdToken: EXPIRED_JWT,
|
||||
}
|
||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||
UseRefreshToken: false,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should skip token refresh when there is no refresh token",
|
||||
setup: func(env *environment) {
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthIdToken: EXPIRED_JWT,
|
||||
OAuthRefreshToken: "",
|
||||
}
|
||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should do token refresh when the token is expired",
|
||||
setup: func(env *environment) {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
AuthId: "subject",
|
||||
UserId: 1,
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should refresh token when the id token is expired",
|
||||
setup: func(env *environment) {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
env.identity = &authn.Identity{ID: "user:1234"}
|
||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
}
|
||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
AuthId: "subject",
|
||||
UserId: 1,
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
OAuthIdToken: EXPIRED_JWT,
|
||||
}
|
||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
socialConnector := &socialtest.MockSocialConnector{}
|
||||
|
||||
env := environment{
|
||||
authInfoService: &authinfotest.FakeService{},
|
||||
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
socialConnector: socialConnector,
|
||||
socialService: &socialtest.FakeSocialService{
|
||||
ExpectedConnector: socialConnector,
|
||||
},
|
||||
}
|
||||
|
||||
if tt.setup != nil {
|
||||
tt.setup(&env)
|
||||
}
|
||||
|
||||
env.service = &Service{
|
||||
AuthInfoService: env.authInfoService,
|
||||
Cfg: setting.NewCfg(),
|
||||
cache: env.cache,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
SocialService: env.socialService,
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||
}
|
||||
|
||||
// token refresh
|
||||
err := env.service.TryTokenRefresh(context.Background(), env.identity)
|
||||
|
||||
// test and validations
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
socialConnector.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||
defaultTime := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenExpiry time.Time
|
||||
idTokenExpiry time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: time.Time{},
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
usr *login.UserAuth
|
||||
expectedTokenRefreshFlag bool
|
||||
expectedTokenDuration time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should not need token refresh when token has no expiration date",
|
||||
usr: &login.UserAuth{},
|
||||
expectedTokenRefreshFlag: false,
|
||||
expectedTokenDuration: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should not need token refresh with an invalid jwt token that might result in an error when parsing",
|
||||
usr: &login.UserAuth{
|
||||
OAuthIdToken: "invalid_jwt_format",
|
||||
},
|
||||
expectedTokenRefreshFlag: false,
|
||||
expectedTokenDuration: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should flag token refresh with id token is expired",
|
||||
usr: &login.UserAuth{
|
||||
OAuthIdToken: EXPIRED_JWT,
|
||||
},
|
||||
expectedTokenRefreshFlag: true,
|
||||
expectedTokenDuration: time.Second,
|
||||
},
|
||||
{
|
||||
name: "should flag token refresh when expiry date is zero",
|
||||
usr: &login.UserAuth{
|
||||
OAuthExpiry: time.Unix(0, 0),
|
||||
},
|
||||
expectedTokenRefreshFlag: true,
|
||||
expectedTokenDuration: time.Second,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, needsTokenRefresh, tokenDuration := needTokenRefresh(tt.usr)
|
||||
|
||||
assert.NotNil(t, token)
|
||||
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
|
||||
assert.Equal(t, tt.expectedTokenDuration, tokenDuration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
|
||||
timeNow := time.Now()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "oauth_access_token",
|
||||
RefreshToken: "refresh_token_found",
|
||||
Expiry: timeNow,
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
type environment struct {
|
||||
authInfoService *authinfotest.FakeService
|
||||
cache *localcache.CacheService
|
||||
socialConnector *socialtest.MockSocialConnector
|
||||
socialService *socialtest.FakeSocialService
|
||||
|
||||
service *Service
|
||||
}
|
||||
tests := []struct {
|
||||
desc string
|
||||
expectedErr error
|
||||
expectedToken *oauth2.Token
|
||||
usr *login.UserAuth
|
||||
setup func(env *environment)
|
||||
}{
|
||||
{
|
||||
desc: "should find and retrieve token from cache",
|
||||
usr: &login.UserAuth{
|
||||
UserId: int64(1234),
|
||||
OAuthAccessToken: "new_access_token",
|
||||
OAuthExpiry: timeNow,
|
||||
},
|
||||
setup: func(env *environment) {
|
||||
env.cache.Set("token-check-1234", token, 1*time.Minute)
|
||||
},
|
||||
expectedToken: &oauth2.Token{
|
||||
AccessToken: "new_access_token",
|
||||
Expiry: timeNow,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
|
||||
usr: &login.UserAuth{
|
||||
UserId: int64(1234),
|
||||
AuthModule: login.SAMLAuthModule,
|
||||
},
|
||||
expectedErr: ErrNotAnOAuthProvider,
|
||||
},
|
||||
{
|
||||
desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
|
||||
usr: &login.UserAuth{
|
||||
UserId: int64(1234),
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
},
|
||||
expectedErr: ErrNoRefreshTokenFound,
|
||||
},
|
||||
{
|
||||
desc: "should not refresh token if the token is not expired",
|
||||
usr: &login.UserAuth{
|
||||
UserId: int64(1234),
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: timeNow.Add(time.Hour),
|
||||
OAuthTokenType: "Bearer",
|
||||
},
|
||||
expectedToken: &oauth2.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: timeNow.Add(time.Hour),
|
||||
TokenType: "Bearer",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should update saved token if the user auth has new access/refresh tokens",
|
||||
usr: &login.UserAuth{
|
||||
UserId: int64(1234),
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
OAuthAccessToken: "new_oauth_access_token",
|
||||
OAuthRefreshToken: "new_refresh_token_found",
|
||||
OAuthExpiry: timeNow,
|
||||
},
|
||||
expectedToken: &oauth2.Token{
|
||||
AccessToken: "oauth_access_token",
|
||||
RefreshToken: "refresh_token_found",
|
||||
Expiry: timeNow,
|
||||
TokenType: "Bearer",
|
||||
},
|
||||
setup: func(env *environment) {
|
||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
socialConnector := &socialtest.MockSocialConnector{}
|
||||
|
||||
env := environment{
|
||||
authInfoService: &authinfotest.FakeService{},
|
||||
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
socialConnector: socialConnector,
|
||||
socialService: &socialtest.FakeSocialService{
|
||||
ExpectedConnector: socialConnector,
|
||||
},
|
||||
}
|
||||
|
||||
if tt.setup != nil {
|
||||
tt.setup(&env)
|
||||
}
|
||||
|
||||
env.service = &Service{
|
||||
AuthInfoService: env.authInfoService,
|
||||
Cfg: setting.NewCfg(),
|
||||
cache: env.cache,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
SocialService: env.socialService,
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||
}
|
||||
|
||||
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)
|
||||
|
||||
if tt.expectedToken != nil {
|
||||
assert.Equal(t, tt.expectedToken, token)
|
||||
}
|
||||
assert.ErrorIs(t, tt.expectedErr, err)
|
||||
socialConnector.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ type MockOauthTokenService struct {
|
||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) error
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||
@ -46,7 +46,7 @@ func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *login.UserAuth) error {
|
||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
|
||||
if m.TryTokenRefreshFunc != nil {
|
||||
return m.TryTokenRefreshFunc(ctx, usr)
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.Use
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *Service) TryTokenRefresh(context.Context, *login.UserAuth) error {
|
||||
func (s *Service) TryTokenRefresh(context.Context, identity.Requester) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
146
pkg/services/oauthtoken/oauthtokentest/service_mock.go
Normal file
146
pkg/services/oauthtoken/oauthtokentest/service_mock.go
Normal file
@ -0,0 +1,146 @@
|
||||
// Code generated by mockery v2.40.1. DO NOT EDIT.
|
||||
|
||||
package oauthtokentest
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
identity "github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
datasources "github.com/grafana/grafana/pkg/services/datasources"
|
||||
|
||||
login "github.com/grafana/grafana/pkg/services/login"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// MockService is an autogenerated mock type for the OAuthTokenService type
|
||||
type MockService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// GetCurrentOAuthToken provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockService) GetCurrentOAuthToken(_a0 context.Context, _a1 identity.Requester) *oauth2.Token {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetCurrentOAuthToken")
|
||||
}
|
||||
|
||||
var r0 *oauth2.Token
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *oauth2.Token); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*oauth2.Token)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// HasOAuthEntry provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockService) HasOAuthEntry(_a0 context.Context, _a1 identity.Requester) (*login.UserAuth, bool, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for HasOAuthEntry")
|
||||
}
|
||||
|
||||
var r0 *login.UserAuth
|
||||
var r1 bool
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) (*login.UserAuth, bool, error)); ok {
|
||||
return rf(_a0, _a1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) *login.UserAuth); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*login.UserAuth)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, identity.Requester) bool); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Get(1).(bool)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(2).(func(context.Context, identity.Requester) error); ok {
|
||||
r2 = rf(_a0, _a1)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// InvalidateOAuthTokens provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockService) InvalidateOAuthTokens(_a0 context.Context, _a1 *login.UserAuth) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for InvalidateOAuthTokens")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *login.UserAuth) error); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsOAuthPassThruEnabled provides a mock function with given fields: _a0
|
||||
func (_m *MockService) IsOAuthPassThruEnabled(_a0 *datasources.DataSource) bool {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IsOAuthPassThruEnabled")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(*datasources.DataSource) bool); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// TryTokenRefresh provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockService) TryTokenRefresh(_a0 context.Context, _a1 identity.Requester) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for TryTokenRefresh")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) error); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockService {
|
||||
mock := &MockService{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
Loading…
Reference in New Issue
Block a user