mirror of
https://github.com/grafana/grafana.git
synced 2024-11-26 02:40:26 -06:00
AuthN: move oauth token hook into session client (#76688)
* Move rotate logic into its own function * Move oauth token sync to session client * Add user to the local cache if refresh tokens are not enabled for the provider so we can skip the check in other requests
This commit is contained in:
parent
8b16f2aca8
commit
455cede699
@ -90,7 +90,7 @@ func ProvideService(
|
||||
s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService))
|
||||
|
||||
if cfg.LoginCookieName != "" {
|
||||
s.RegisterClient(clients.ProvideSession(cfg, sessionService, features))
|
||||
s.RegisterClient(clients.ProvideSession(cfg, features, sessionService, oauthTokenService, socialService))
|
||||
}
|
||||
|
||||
var proxyClients []authn.ProxyClient
|
||||
@ -157,14 +157,9 @@ func ProvideService(
|
||||
s.RegisterPostAuthHook(userSyncService.SyncUserHook, 10)
|
||||
s.RegisterPostAuthHook(userSyncService.EnableUserHook, 20)
|
||||
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgRolesHook, 30)
|
||||
s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120)
|
||||
|
||||
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
||||
s.RegisterPostAuthHook(sync.ProvideOAuthTokenSync(oauthTokenService, sessionService, socialService).SyncOauthTokenHook, 60)
|
||||
}
|
||||
|
||||
s.RegisterPostAuthHook(userSyncService.FetchSyncedUserHook, 100)
|
||||
s.RegisterPostAuthHook(sync.ProvidePermissionsSync(accessControlService).SyncPermissionsHook, 110)
|
||||
s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120)
|
||||
|
||||
return s
|
||||
}
|
||||
|
@ -1,174 +0,0 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||
return &OAuthTokenSync{
|
||||
log.New("oauth_token.sync"),
|
||||
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
service,
|
||||
sessionService,
|
||||
socialService,
|
||||
}
|
||||
}
|
||||
|
||||
type OAuthTokenSync struct {
|
||||
log log.Logger
|
||||
cache *localcache.CacheService
|
||||
service oauthtoken.OAuthTokenService
|
||||
sessionService auth.UserTokenService
|
||||
socialService social.Service
|
||||
}
|
||||
|
||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||
namespace, id := identity.NamespacedID()
|
||||
// only perform oauth token check if identity is a user
|
||||
if namespace != authn.NamespaceUser {
|
||||
return nil
|
||||
}
|
||||
|
||||
// not authenticated through session tokens, so we can skip this hook
|
||||
if identity.SessionToken == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if we recently have performed this it would be cached, so we can skip the hook
|
||||
if _, ok := s.cache.Get(identity.ID); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
||||
if currentOAuthInfo == nil {
|
||||
s.log.Warn("OAuth provider not found", "provider", provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
||||
if !currentOAuthInfo.UseRefreshToken {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
|
||||
hasIdTokenExpired := false
|
||||
idTokenExpires := time.Time{}
|
||||
|
||||
if !idTokenExpiry.IsZero() {
|
||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||
}
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return maxOAuthTokenCacheTTL
|
||||
}
|
||||
|
||||
min := func(a, b time.Duration) time.Duration {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||
if token.OAuthIdToken == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
var claims Claims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
@ -1,258 +0,0 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
)
|
||||
|
||||
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
identity *authn.Identity
|
||||
oauthInfo *social.OAuthInfo
|
||||
|
||||
expectedHasEntryToken *login.UserAuth
|
||||
expectHasEntryCalled bool
|
||||
|
||||
expectedTryRefreshErr error
|
||||
expectTryRefreshTokenCalled bool
|
||||
|
||||
expectRevokeTokenCalled bool
|
||||
expectInvalidateOauthTokensCalled bool
|
||||
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should skip sync when identity is not a user",
|
||||
identity: &authn.Identity{ID: "service-account:1"},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync for when access token don't have expire time",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when is has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedTryRefreshErr: errors.New("some err"),
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectInvalidateOauthTokensCalled: true,
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
expectedErr: authn.ErrExpiredAccessToken,
|
||||
}, {
|
||||
desc: "should skip sync when use_refresh_token is disabled",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: false,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when ID token has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var (
|
||||
hasEntryCalled bool
|
||||
tryRefreshCalled bool
|
||||
invalidateTokensCalled bool
|
||||
revokeTokenCalled bool
|
||||
)
|
||||
|
||||
service := &oauthtokentest.MockOauthTokenService{
|
||||
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
||||
hasEntryCalled = true
|
||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||
},
|
||||
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
||||
invalidateTokensCalled = true
|
||||
return nil
|
||||
},
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
||||
tryRefreshCalled = true
|
||||
return tt.expectedTryRefreshErr
|
||||
},
|
||||
}
|
||||
|
||||
sessionService := &authtest.FakeUserAuthTokenService{
|
||||
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
|
||||
revokeTokenCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if tt.oauthInfo == nil {
|
||||
tt.oauthInfo = &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
}
|
||||
}
|
||||
|
||||
socialService := &socialtest.FakeSocialService{
|
||||
ExpectedAuthInfoProvider: tt.oauthInfo,
|
||||
}
|
||||
|
||||
sync := &OAuthTokenSync{
|
||||
log: log.NewNopLogger(),
|
||||
cache: localcache.New(0, 0),
|
||||
service: service,
|
||||
sessionService: sessionService,
|
||||
socialService: socialService,
|
||||
}
|
||||
|
||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
||||
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
||||
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
||||
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||
type Header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
type Payload struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||
require.NoError(t, err)
|
||||
u := expiryDate.UTC().Unix()
|
||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeSignature := []byte("6ICJm")
|
||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||
defaultTime := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenExpiry time.Time
|
||||
idTokenExpiry time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: time.Time{},
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
@ -394,4 +394,5 @@ func syncSignedInUserToIdentity(usr *user.SignedInUser, identity *authn.Identity
|
||||
identity.LastSeenAt = usr.LastSeenAt
|
||||
identity.IsDisabled = usr.IsDisabled
|
||||
identity.IsGrafanaAdmin = &usr.IsGrafanaAdmin
|
||||
identity.AuthenticatedBy = usr.AuthenticatedBy
|
||||
}
|
||||
|
@ -3,14 +3,23 @@ package clients
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/network"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
@ -18,21 +27,31 @@ import (
|
||||
var _ authn.HookClient = new(Session)
|
||||
var _ authn.ContextAwareClient = new(Session)
|
||||
|
||||
func ProvideSession(cfg *setting.Cfg, sessionService auth.UserTokenService,
|
||||
features *featuremgmt.FeatureManager) *Session {
|
||||
func ProvideSession(
|
||||
cfg *setting.Cfg, features *featuremgmt.FeatureManager, sessionService auth.UserTokenService,
|
||||
oauthTokenService oauthtoken.OAuthTokenService, socialService social.Service,
|
||||
) *Session {
|
||||
return &Session{
|
||||
cfg: cfg,
|
||||
features: features,
|
||||
sessionService: sessionService,
|
||||
log: log.New(authn.ClientSession),
|
||||
cfg: cfg,
|
||||
features: features,
|
||||
sessionService: sessionService,
|
||||
oauthTokenService: oauthTokenService,
|
||||
socialService: socialService,
|
||||
log: log.New(authn.ClientSession),
|
||||
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
cfg *setting.Cfg
|
||||
features *featuremgmt.FeatureManager
|
||||
sessionService auth.UserTokenService
|
||||
log log.Logger
|
||||
log log.Logger
|
||||
cfg *setting.Cfg
|
||||
features *featuremgmt.FeatureManager
|
||||
|
||||
socialService social.Service
|
||||
sessionService auth.UserTokenService
|
||||
oauthTokenService oauthtoken.OAuthTokenService
|
||||
|
||||
cache *localcache.CacheService
|
||||
}
|
||||
|
||||
func (s *Session) Name() string {
|
||||
@ -88,7 +107,19 @@ func (s *Session) Priority() uint {
|
||||
}
|
||||
|
||||
func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
|
||||
if identity.SessionToken == nil || s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) {
|
||||
if identity.SessionToken == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.rotateTokenHook(ctx, identity, r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.syncOAuthTokenHook(ctx, identity, r)
|
||||
}
|
||||
|
||||
func (s *Session) rotateTokenHook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
|
||||
if s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -123,3 +154,143 @@ func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.R
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) syncOAuthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||
if !s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
||||
return nil
|
||||
}
|
||||
|
||||
namespace, id := identity.NamespacedID()
|
||||
// only perform oauth token check if identity is a user
|
||||
if namespace != authn.NamespaceUser {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if we recently have performed this it would be cached, so we can skip the hook
|
||||
if _, ok := s.cache.Get(identity.ID); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, exists, _ := s.oauthTokenService.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists {
|
||||
// if user is not authenticated through oauth we can skip this check by adding the id to the cache
|
||||
s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
||||
if currentOAuthInfo == nil {
|
||||
s.log.Warn("OAuth provider not found", "provider", provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
||||
if !currentOAuthInfo.UseRefreshToken {
|
||||
// refresh token is not configured for provider so we can skip this check by adding the id to the cache
|
||||
s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL)
|
||||
return nil
|
||||
}
|
||||
|
||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
|
||||
hasIdTokenExpired := false
|
||||
idTokenExpires := time.Time{}
|
||||
|
||||
if !idTokenExpiry.IsZero() {
|
||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||
}
|
||||
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.oauthTokenService.TryTokenRefresh(updateCtx, token); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.oauthTokenService.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return maxOAuthTokenCacheTTL
|
||||
}
|
||||
|
||||
min := func(a, b time.Duration) time.Duration {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||
if token.OAuthIdToken == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
var claims Claims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
@ -2,6 +2,10 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
@ -10,11 +14,16 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
@ -29,7 +38,7 @@ func TestSession_Test(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
cfg.LoginCookieName = ""
|
||||
cfg.LoginMaxLifetime = 20 * time.Second
|
||||
s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{}, featuremgmt.WithFeatures())
|
||||
s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{}, nil, nil)
|
||||
|
||||
disabled := s.Test(context.Background(), &authn.Request{HTTPRequest: validHTTPReq})
|
||||
assert.False(t, disabled)
|
||||
@ -145,7 +154,7 @@ func TestSession_Authenticate(t *testing.T) {
|
||||
cfg.LoginCookieName = cookieName
|
||||
cfg.TokenRotationIntervalMinutes = 10
|
||||
cfg.LoginMaxLifetime = 20 * time.Second
|
||||
s := ProvideSession(cfg, tt.fields.sessionService, tt.fields.features)
|
||||
s := ProvideSession(cfg, tt.fields.features, tt.fields.sessionService, nil, nil)
|
||||
|
||||
got, err := s.Authenticate(context.Background(), tt.args.r)
|
||||
require.True(t, (err != nil) == tt.wantErr, err)
|
||||
@ -175,17 +184,17 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) {
|
||||
f.Status = statusCode
|
||||
}
|
||||
|
||||
func TestSession_Hook(t *testing.T) {
|
||||
func TestSession_RotateSessionHook(t *testing.T) {
|
||||
t.Run("should rotate token", func(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
cfg.LoginCookieName = "grafana-session"
|
||||
cfg.LoginMaxLifetime = 20 * time.Second
|
||||
s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(_ context.Context, token *auth.UserToken, _ net.IP, _ string) (bool, *auth.UserToken, error) {
|
||||
token.UnhashedToken = "new-token"
|
||||
return true, token, nil
|
||||
},
|
||||
}, featuremgmt.WithFeatures())
|
||||
}, nil, nil)
|
||||
|
||||
sampleID := &authn.Identity{
|
||||
SessionToken: &auth.UserToken{
|
||||
@ -206,7 +215,7 @@ func TestSession_Hook(t *testing.T) {
|
||||
Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter),
|
||||
}
|
||||
|
||||
err := s.Hook(context.Background(), sampleID, resp)
|
||||
err := s.rotateTokenHook(context.Background(), sampleID, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp.Resp.WriteHeader(201)
|
||||
@ -219,7 +228,7 @@ func TestSession_Hook(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should not rotate token with feature flag", func(t *testing.T) {
|
||||
s := ProvideSession(setting.NewCfg(), nil, featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation))
|
||||
s := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation), nil, nil, nil)
|
||||
|
||||
req := &authn.Request{}
|
||||
identity := &authn.Identity{}
|
||||
@ -227,3 +236,226 @@ func TestSession_Hook(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSession_SyncOAuthTokenHook(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
identity *authn.Identity
|
||||
oauthInfo *social.OAuthInfo
|
||||
|
||||
expectedHasEntryToken *login.UserAuth
|
||||
expectHasEntryCalled bool
|
||||
|
||||
expectedTryRefreshErr error
|
||||
expectTryRefreshTokenCalled bool
|
||||
|
||||
expectRevokeTokenCalled bool
|
||||
expectInvalidateOauthTokensCalled bool
|
||||
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should skip sync when identity is not a user",
|
||||
identity: &authn.Identity{ID: "service-account:1"},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync for when access token don't have expire time",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when is has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedTryRefreshErr: errors.New("some err"),
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectInvalidateOauthTokensCalled: true,
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
expectedErr: authn.ErrExpiredAccessToken,
|
||||
}, {
|
||||
desc: "should skip sync when use_refresh_token is disabled",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: false,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when ID token has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var (
|
||||
hasEntryCalled bool
|
||||
tryRefreshCalled bool
|
||||
invalidateTokensCalled bool
|
||||
revokeTokenCalled bool
|
||||
)
|
||||
|
||||
oauthTokenService := &oauthtokentest.MockOauthTokenService{
|
||||
HasOAuthEntryFunc: func(_ context.Context, _ identity.Requester) (*login.UserAuth, bool, error) {
|
||||
hasEntryCalled = true
|
||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||
},
|
||||
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||
invalidateTokensCalled = true
|
||||
return nil
|
||||
},
|
||||
TryTokenRefreshFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||
tryRefreshCalled = true
|
||||
return tt.expectedTryRefreshErr
|
||||
},
|
||||
}
|
||||
|
||||
sessionService := &authtest.FakeUserAuthTokenService{
|
||||
RevokeTokenProvider: func(_ context.Context, _ *auth.UserToken, _ bool) error {
|
||||
revokeTokenCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if tt.oauthInfo == nil {
|
||||
tt.oauthInfo = &social.OAuthInfo{
|
||||
UseRefreshToken: true,
|
||||
}
|
||||
}
|
||||
|
||||
socialService := &socialtest.FakeSocialService{
|
||||
ExpectedAuthInfoProvider: tt.oauthInfo,
|
||||
}
|
||||
|
||||
client := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck), sessionService, oauthTokenService, socialService)
|
||||
|
||||
err := client.syncOAuthTokenHook(context.Background(), tt.identity, nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
||||
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
||||
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
||||
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakeIDToken is used to create sa fake invalid token to verify expiry logic
|
||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||
type Header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
type Payload struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||
require.NoError(t, err)
|
||||
u := expiryDate.UTC().Unix()
|
||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeSignature := []byte("6ICJm")
|
||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||
}
|
||||
|
||||
func TestGetOAuthTokenCacheTTL(t *testing.T) {
|
||||
defaultTime := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenExpiry time.Time
|
||||
idTokenExpiry time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: time.Time{},
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user