mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AuthN: move oauth token hook into session client (#76688)
* Move rotate logic into its own function * Move oauth token sync to session client * Add user to the local cache if refresh tokens are not enabled for the provider so we can skip the check in other requests
This commit is contained in:
parent
8b16f2aca8
commit
455cede699
@ -90,7 +90,7 @@ func ProvideService(
|
|||||||
s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService))
|
s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService))
|
||||||
|
|
||||||
if cfg.LoginCookieName != "" {
|
if cfg.LoginCookieName != "" {
|
||||||
s.RegisterClient(clients.ProvideSession(cfg, sessionService, features))
|
s.RegisterClient(clients.ProvideSession(cfg, features, sessionService, oauthTokenService, socialService))
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyClients []authn.ProxyClient
|
var proxyClients []authn.ProxyClient
|
||||||
@ -157,14 +157,9 @@ func ProvideService(
|
|||||||
s.RegisterPostAuthHook(userSyncService.SyncUserHook, 10)
|
s.RegisterPostAuthHook(userSyncService.SyncUserHook, 10)
|
||||||
s.RegisterPostAuthHook(userSyncService.EnableUserHook, 20)
|
s.RegisterPostAuthHook(userSyncService.EnableUserHook, 20)
|
||||||
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgRolesHook, 30)
|
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgRolesHook, 30)
|
||||||
s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120)
|
|
||||||
|
|
||||||
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
|
||||||
s.RegisterPostAuthHook(sync.ProvideOAuthTokenSync(oauthTokenService, sessionService, socialService).SyncOauthTokenHook, 60)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.RegisterPostAuthHook(userSyncService.FetchSyncedUserHook, 100)
|
s.RegisterPostAuthHook(userSyncService.FetchSyncedUserHook, 100)
|
||||||
s.RegisterPostAuthHook(sync.ProvidePermissionsSync(accessControlService).SyncPermissionsHook, 110)
|
s.RegisterPostAuthHook(sync.ProvidePermissionsSync(accessControlService).SyncPermissionsHook, 110)
|
||||||
|
s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120)
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -1,174 +0,0 @@
|
|||||||
package sync
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
|
||||||
return &OAuthTokenSync{
|
|
||||||
log.New("oauth_token.sync"),
|
|
||||||
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
|
||||||
service,
|
|
||||||
sessionService,
|
|
||||||
socialService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type OAuthTokenSync struct {
|
|
||||||
log log.Logger
|
|
||||||
cache *localcache.CacheService
|
|
||||||
service oauthtoken.OAuthTokenService
|
|
||||||
sessionService auth.UserTokenService
|
|
||||||
socialService social.Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
|
||||||
namespace, id := identity.NamespacedID()
|
|
||||||
// only perform oauth token check if identity is a user
|
|
||||||
if namespace != authn.NamespaceUser {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// not authenticated through session tokens, so we can skip this hook
|
|
||||||
if identity.SessionToken == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we recently have performed this it would be cached, so we can skip the hook
|
|
||||||
if _, ok := s.cache.Get(identity.ID); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
|
||||||
// user is not authenticated through oauth so skip further checks
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
|
||||||
if err != nil {
|
|
||||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// token has no expire time configured, so we don't have to refresh it
|
|
||||||
if token.OAuthExpiry.IsZero() {
|
|
||||||
// cache the token check, so we don't perform it on every request
|
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the token's auth provider (f.e. azuread)
|
|
||||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
|
||||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
|
||||||
if currentOAuthInfo == nil {
|
|
||||||
s.log.Warn("OAuth provider not found", "provider", provider)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
|
||||||
if !currentOAuthInfo.UseRefreshToken {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
|
||||||
|
|
||||||
hasIdTokenExpired := false
|
|
||||||
idTokenExpires := time.Time{}
|
|
||||||
|
|
||||||
if !idTokenExpiry.IsZero() {
|
|
||||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
|
||||||
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
|
||||||
}
|
|
||||||
// token has not expired, so we don't have to refresh it
|
|
||||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
|
||||||
// cache the token check, so we don't perform it on every request
|
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
|
||||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
|
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
|
||||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
|
||||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
|
||||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
|
||||||
|
|
||||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
|
||||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
|
||||||
return maxOAuthTokenCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
min := func(a, b time.Duration) time.Duration {
|
|
||||||
if a <= b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
|
||||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
|
||||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
|
||||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
|
||||||
if token.OAuthIdToken == "" {
|
|
||||||
return time.Time{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
|
||||||
if err != nil {
|
|
||||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Claims struct {
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
var claims Claims
|
|
||||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
|
||||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Unix(claims.Exp, 0), nil
|
|
||||||
}
|
|
@ -1,258 +0,0 @@
|
|||||||
package sync
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
|
||||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
desc string
|
|
||||||
identity *authn.Identity
|
|
||||||
oauthInfo *social.OAuthInfo
|
|
||||||
|
|
||||||
expectedHasEntryToken *login.UserAuth
|
|
||||||
expectHasEntryCalled bool
|
|
||||||
|
|
||||||
expectedTryRefreshErr error
|
|
||||||
expectTryRefreshTokenCalled bool
|
|
||||||
|
|
||||||
expectRevokeTokenCalled bool
|
|
||||||
expectInvalidateOauthTokensCalled bool
|
|
||||||
|
|
||||||
expectedErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []testCase{
|
|
||||||
{
|
|
||||||
desc: "should skip sync when identity is not a user",
|
|
||||||
identity: &authn.Identity{ID: "service-account:1"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
|
||||||
identity: &authn.Identity{ID: "user:1"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync for when access token don't have expire time",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when access token has no expired yet",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip sync when access token has no expired yet",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should refresh access token when is has expired",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectTryRefreshTokenCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectedTryRefreshErr: errors.New("some err"),
|
|
||||||
expectTryRefreshTokenCalled: true,
|
|
||||||
expectInvalidateOauthTokensCalled: true,
|
|
||||||
expectRevokeTokenCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
||||||
expectedErr: authn.ErrExpiredAccessToken,
|
|
||||||
}, {
|
|
||||||
desc: "should skip sync when use_refresh_token is disabled",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectTryRefreshTokenCalled: false,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
||||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should refresh access token when ID token has expired",
|
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
|
||||||
expectHasEntryCalled: true,
|
|
||||||
expectTryRefreshTokenCalled: true,
|
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.desc, func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
hasEntryCalled bool
|
|
||||||
tryRefreshCalled bool
|
|
||||||
invalidateTokensCalled bool
|
|
||||||
revokeTokenCalled bool
|
|
||||||
)
|
|
||||||
|
|
||||||
service := &oauthtokentest.MockOauthTokenService{
|
|
||||||
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
|
||||||
hasEntryCalled = true
|
|
||||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
|
||||||
},
|
|
||||||
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
|
||||||
invalidateTokensCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
|
||||||
tryRefreshCalled = true
|
|
||||||
return tt.expectedTryRefreshErr
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionService := &authtest.FakeUserAuthTokenService{
|
|
||||||
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
|
|
||||||
revokeTokenCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.oauthInfo == nil {
|
|
||||||
tt.oauthInfo = &social.OAuthInfo{
|
|
||||||
UseRefreshToken: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
socialService := &socialtest.FakeSocialService{
|
|
||||||
ExpectedAuthInfoProvider: tt.oauthInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
sync := &OAuthTokenSync{
|
|
||||||
log: log.NewNopLogger(),
|
|
||||||
cache: localcache.New(0, 0),
|
|
||||||
service: service,
|
|
||||||
sessionService: sessionService,
|
|
||||||
socialService: socialService,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
|
||||||
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
|
||||||
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
|
||||||
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
|
||||||
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
|
||||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
|
||||||
type Header struct {
|
|
||||||
Kid string `json:"kid"`
|
|
||||||
Alg string `json:"alg"`
|
|
||||||
}
|
|
||||||
type Payload struct {
|
|
||||||
Iss string `json:"iss"`
|
|
||||||
Sub string `json:"sub"`
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
u := expiryDate.UTC().Unix()
|
|
||||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
fakeSignature := []byte("6ICJm")
|
|
||||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
|
||||||
defaultTime := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accessTokenExpiry time.Time
|
|
||||||
idTokenExpiry time.Time
|
|
||||||
want time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
|
||||||
accessTokenExpiry: time.Time{},
|
|
||||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
|
||||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: time.Time{},
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
|
||||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
|
||||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
|
||||||
want: maxOAuthTokenCacheTTL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -394,4 +394,5 @@ func syncSignedInUserToIdentity(usr *user.SignedInUser, identity *authn.Identity
|
|||||||
identity.LastSeenAt = usr.LastSeenAt
|
identity.LastSeenAt = usr.LastSeenAt
|
||||||
identity.IsDisabled = usr.IsDisabled
|
identity.IsDisabled = usr.IsDisabled
|
||||||
identity.IsGrafanaAdmin = &usr.IsGrafanaAdmin
|
identity.IsGrafanaAdmin = &usr.IsGrafanaAdmin
|
||||||
|
identity.AuthenticatedBy = usr.AuthenticatedBy
|
||||||
}
|
}
|
||||||
|
@ -3,14 +3,23 @@ package clients
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/infra/network"
|
"github.com/grafana/grafana/pkg/infra/network"
|
||||||
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
@ -18,21 +27,31 @@ import (
|
|||||||
var _ authn.HookClient = new(Session)
|
var _ authn.HookClient = new(Session)
|
||||||
var _ authn.ContextAwareClient = new(Session)
|
var _ authn.ContextAwareClient = new(Session)
|
||||||
|
|
||||||
func ProvideSession(cfg *setting.Cfg, sessionService auth.UserTokenService,
|
func ProvideSession(
|
||||||
features *featuremgmt.FeatureManager) *Session {
|
cfg *setting.Cfg, features *featuremgmt.FeatureManager, sessionService auth.UserTokenService,
|
||||||
|
oauthTokenService oauthtoken.OAuthTokenService, socialService social.Service,
|
||||||
|
) *Session {
|
||||||
return &Session{
|
return &Session{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
features: features,
|
features: features,
|
||||||
sessionService: sessionService,
|
sessionService: sessionService,
|
||||||
log: log.New(authn.ClientSession),
|
oauthTokenService: oauthTokenService,
|
||||||
|
socialService: socialService,
|
||||||
|
log: log.New(authn.ClientSession),
|
||||||
|
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
cfg *setting.Cfg
|
log log.Logger
|
||||||
features *featuremgmt.FeatureManager
|
cfg *setting.Cfg
|
||||||
sessionService auth.UserTokenService
|
features *featuremgmt.FeatureManager
|
||||||
log log.Logger
|
|
||||||
|
socialService social.Service
|
||||||
|
sessionService auth.UserTokenService
|
||||||
|
oauthTokenService oauthtoken.OAuthTokenService
|
||||||
|
|
||||||
|
cache *localcache.CacheService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Name() string {
|
func (s *Session) Name() string {
|
||||||
@ -88,7 +107,19 @@ func (s *Session) Priority() uint {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
|
func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
|
||||||
if identity.SessionToken == nil || s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) {
|
if identity.SessionToken == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.rotateTokenHook(ctx, identity, r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.syncOAuthTokenHook(ctx, identity, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) rotateTokenHook(ctx context.Context, identity *authn.Identity, r *authn.Request) error {
|
||||||
|
if s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,3 +154,143 @@ func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.R
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) syncOAuthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||||
|
if !s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace, id := identity.NamespacedID()
|
||||||
|
// only perform oauth token check if identity is a user
|
||||||
|
if namespace != authn.NamespaceUser {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we recently have performed this it would be cached, so we can skip the hook
|
||||||
|
if _, ok := s.cache.Get(identity.ID); ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, exists, _ := s.oauthTokenService.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||||
|
// user is not authenticated through oauth so skip further checks
|
||||||
|
if !exists {
|
||||||
|
// if user is not authenticated through oauth we can skip this check by adding the id to the cache
|
||||||
|
s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the token's auth provider (f.e. azuread)
|
||||||
|
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
|
||||||
|
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
|
||||||
|
if currentOAuthInfo == nil {
|
||||||
|
s.log.Warn("OAuth provider not found", "provider", provider)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if refresh token handling is disabled for this provider, we can skip the hook
|
||||||
|
if !currentOAuthInfo.UseRefreshToken {
|
||||||
|
// refresh token is not configured for provider so we can skip this check by adding the id to the cache
|
||||||
|
s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||||
|
if err != nil {
|
||||||
|
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// token has no expire time configured, so we don't have to refresh it
|
||||||
|
if token.OAuthExpiry.IsZero() {
|
||||||
|
// cache the token check, so we don't perform it on every request
|
||||||
|
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||||
|
|
||||||
|
hasIdTokenExpired := false
|
||||||
|
idTokenExpires := time.Time{}
|
||||||
|
|
||||||
|
if !idTokenExpiry.IsZero() {
|
||||||
|
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||||
|
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// token has not expired, so we don't have to refresh it
|
||||||
|
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||||
|
// cache the token check, so we don't perform it on every request
|
||||||
|
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||||
|
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.oauthTokenService.TryTokenRefresh(updateCtx, token); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||||
|
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.oauthTokenService.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||||
|
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||||
|
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||||
|
|
||||||
|
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||||
|
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||||
|
return maxOAuthTokenCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
min := func(a, b time.Duration) time.Duration {
|
||||||
|
if a <= b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||||
|
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||||
|
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||||
|
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||||
|
if token.OAuthIdToken == "" {
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
var claims Claims
|
||||||
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Unix(claims.Exp, 0), nil
|
||||||
|
}
|
||||||
|
@ -2,6 +2,10 @@ package clients
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
@ -10,11 +14,16 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
@ -29,7 +38,7 @@ func TestSession_Test(t *testing.T) {
|
|||||||
cfg := setting.NewCfg()
|
cfg := setting.NewCfg()
|
||||||
cfg.LoginCookieName = ""
|
cfg.LoginCookieName = ""
|
||||||
cfg.LoginMaxLifetime = 20 * time.Second
|
cfg.LoginMaxLifetime = 20 * time.Second
|
||||||
s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{}, featuremgmt.WithFeatures())
|
s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{}, nil, nil)
|
||||||
|
|
||||||
disabled := s.Test(context.Background(), &authn.Request{HTTPRequest: validHTTPReq})
|
disabled := s.Test(context.Background(), &authn.Request{HTTPRequest: validHTTPReq})
|
||||||
assert.False(t, disabled)
|
assert.False(t, disabled)
|
||||||
@ -145,7 +154,7 @@ func TestSession_Authenticate(t *testing.T) {
|
|||||||
cfg.LoginCookieName = cookieName
|
cfg.LoginCookieName = cookieName
|
||||||
cfg.TokenRotationIntervalMinutes = 10
|
cfg.TokenRotationIntervalMinutes = 10
|
||||||
cfg.LoginMaxLifetime = 20 * time.Second
|
cfg.LoginMaxLifetime = 20 * time.Second
|
||||||
s := ProvideSession(cfg, tt.fields.sessionService, tt.fields.features)
|
s := ProvideSession(cfg, tt.fields.features, tt.fields.sessionService, nil, nil)
|
||||||
|
|
||||||
got, err := s.Authenticate(context.Background(), tt.args.r)
|
got, err := s.Authenticate(context.Background(), tt.args.r)
|
||||||
require.True(t, (err != nil) == tt.wantErr, err)
|
require.True(t, (err != nil) == tt.wantErr, err)
|
||||||
@ -175,17 +184,17 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) {
|
|||||||
f.Status = statusCode
|
f.Status = statusCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSession_Hook(t *testing.T) {
|
func TestSession_RotateSessionHook(t *testing.T) {
|
||||||
t.Run("should rotate token", func(t *testing.T) {
|
t.Run("should rotate token", func(t *testing.T) {
|
||||||
cfg := setting.NewCfg()
|
cfg := setting.NewCfg()
|
||||||
cfg.LoginCookieName = "grafana-session"
|
cfg.LoginCookieName = "grafana-session"
|
||||||
cfg.LoginMaxLifetime = 20 * time.Second
|
cfg.LoginMaxLifetime = 20 * time.Second
|
||||||
s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{
|
s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{
|
||||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
TryRotateTokenProvider: func(_ context.Context, token *auth.UserToken, _ net.IP, _ string) (bool, *auth.UserToken, error) {
|
||||||
token.UnhashedToken = "new-token"
|
token.UnhashedToken = "new-token"
|
||||||
return true, token, nil
|
return true, token, nil
|
||||||
},
|
},
|
||||||
}, featuremgmt.WithFeatures())
|
}, nil, nil)
|
||||||
|
|
||||||
sampleID := &authn.Identity{
|
sampleID := &authn.Identity{
|
||||||
SessionToken: &auth.UserToken{
|
SessionToken: &auth.UserToken{
|
||||||
@ -206,7 +215,7 @@ func TestSession_Hook(t *testing.T) {
|
|||||||
Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter),
|
Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.Hook(context.Background(), sampleID, resp)
|
err := s.rotateTokenHook(context.Background(), sampleID, resp)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp.Resp.WriteHeader(201)
|
resp.Resp.WriteHeader(201)
|
||||||
@ -219,7 +228,7 @@ func TestSession_Hook(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should not rotate token with feature flag", func(t *testing.T) {
|
t.Run("should not rotate token with feature flag", func(t *testing.T) {
|
||||||
s := ProvideSession(setting.NewCfg(), nil, featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation))
|
s := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation), nil, nil, nil)
|
||||||
|
|
||||||
req := &authn.Request{}
|
req := &authn.Request{}
|
||||||
identity := &authn.Identity{}
|
identity := &authn.Identity{}
|
||||||
@ -227,3 +236,226 @@ func TestSession_Hook(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSession_SyncOAuthTokenHook(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
identity *authn.Identity
|
||||||
|
oauthInfo *social.OAuthInfo
|
||||||
|
|
||||||
|
expectedHasEntryToken *login.UserAuth
|
||||||
|
expectHasEntryCalled bool
|
||||||
|
|
||||||
|
expectedTryRefreshErr error
|
||||||
|
expectTryRefreshTokenCalled bool
|
||||||
|
|
||||||
|
expectRevokeTokenCalled bool
|
||||||
|
expectInvalidateOauthTokensCalled bool
|
||||||
|
|
||||||
|
expectedErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should skip sync when identity is not a user",
|
||||||
|
identity: &authn.Identity{ID: "service-account:1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync for when access token don't have expire time",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when access token has no expired yet",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when access token has no expired yet",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should refresh access token when is has expired",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedTryRefreshErr: errors.New("some err"),
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectInvalidateOauthTokensCalled: true,
|
||||||
|
expectRevokeTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
expectedErr: authn.ErrExpiredAccessToken,
|
||||||
|
}, {
|
||||||
|
desc: "should skip sync when use_refresh_token is disabled",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: false,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should refresh access token when ID token has expired",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
hasEntryCalled bool
|
||||||
|
tryRefreshCalled bool
|
||||||
|
invalidateTokensCalled bool
|
||||||
|
revokeTokenCalled bool
|
||||||
|
)
|
||||||
|
|
||||||
|
oauthTokenService := &oauthtokentest.MockOauthTokenService{
|
||||||
|
HasOAuthEntryFunc: func(_ context.Context, _ identity.Requester) (*login.UserAuth, bool, error) {
|
||||||
|
hasEntryCalled = true
|
||||||
|
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||||
|
},
|
||||||
|
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||||
|
invalidateTokensCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
TryTokenRefreshFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||||
|
tryRefreshCalled = true
|
||||||
|
return tt.expectedTryRefreshErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionService := &authtest.FakeUserAuthTokenService{
|
||||||
|
RevokeTokenProvider: func(_ context.Context, _ *auth.UserToken, _ bool) error {
|
||||||
|
revokeTokenCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.oauthInfo == nil {
|
||||||
|
tt.oauthInfo = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
socialService := &socialtest.FakeSocialService{
|
||||||
|
ExpectedAuthInfoProvider: tt.oauthInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
client := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck), sessionService, oauthTokenService, socialService)
|
||||||
|
|
||||||
|
err := client.syncOAuthTokenHook(context.Background(), tt.identity, nil)
|
||||||
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
||||||
|
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
||||||
|
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
||||||
|
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeIDToken is used to create sa fake invalid token to verify expiry logic
|
||||||
|
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||||
|
type Header struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
}
|
||||||
|
type Payload struct {
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
u := expiryDate.UTC().Unix()
|
||||||
|
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
fakeSignature := []byte("6ICJm")
|
||||||
|
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOAuthTokenCacheTTL(t *testing.T) {
|
||||||
|
defaultTime := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessTokenExpiry time.Time
|
||||||
|
idTokenExpiry time.Time
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user