From ed1c50233f219bf8af52191c70e3e651eb7834b8 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Fri, 20 Oct 2023 16:09:46 +0200 Subject: [PATCH] Revert "AuthN: move oauth token hook into session client" (#76882) Revert "AuthN: move oauth token hook into session client (#76688)" This reverts commit 455cede6992cfad57f1c5b1749c9454d1475dba6. --- pkg/services/authn/authnimpl/service.go | 9 +- .../authn/authnimpl/sync/oauth_token_sync.go | 174 ++++++++++++ .../authnimpl/sync/oauth_token_sync_test.go | 258 ++++++++++++++++++ .../authn/authnimpl/sync/user_sync.go | 1 - pkg/services/authn/clients/session.go | 193 +------------ pkg/services/authn/clients/session_test.go | 248 +---------------- 6 files changed, 458 insertions(+), 425 deletions(-) create mode 100644 pkg/services/authn/authnimpl/sync/oauth_token_sync.go create mode 100644 pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index 5b364343db7..4c50c97e9ff 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -90,7 +90,7 @@ func ProvideService( s.RegisterClient(clients.ProvideAPIKey(apikeyService, userService)) if cfg.LoginCookieName != "" { - s.RegisterClient(clients.ProvideSession(cfg, features, sessionService, oauthTokenService, socialService)) + s.RegisterClient(clients.ProvideSession(cfg, sessionService, features)) } var proxyClients []authn.ProxyClient @@ -157,9 +157,14 @@ func ProvideService( s.RegisterPostAuthHook(userSyncService.SyncUserHook, 10) s.RegisterPostAuthHook(userSyncService.EnableUserHook, 20) s.RegisterPostAuthHook(orgUserSyncService.SyncOrgRolesHook, 30) + s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120) + + if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { + s.RegisterPostAuthHook(sync.ProvideOAuthTokenSync(oauthTokenService, sessionService, socialService).SyncOauthTokenHook, 60) + } + s.RegisterPostAuthHook(userSyncService.FetchSyncedUserHook, 100) s.RegisterPostAuthHook(sync.ProvidePermissionsSync(accessControlService).SyncPermissionsHook, 110) - s.RegisterPostAuthHook(userSyncService.SyncLastSeenHook, 120) return s } diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go new file mode 100644 index 00000000000..cfe3978b5ce --- /dev/null +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go @@ -0,0 +1,174 @@ +package sync + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/go-jose/go-jose/v3/jwt" + + "github.com/grafana/grafana/pkg/infra/localcache" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken" + "github.com/grafana/grafana/pkg/services/user" +) + +func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync { + return &OAuthTokenSync{ + log.New("oauth_token.sync"), + localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute), + service, + sessionService, + socialService, + } +} + +type OAuthTokenSync struct { + log log.Logger + cache *localcache.CacheService + service oauthtoken.OAuthTokenService + sessionService auth.UserTokenService + socialService social.Service +} + +func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { + namespace, id := identity.NamespacedID() + // only perform oauth token check if identity is a user + if namespace != authn.NamespaceUser { + return nil + } + + // not authenticated through session tokens, so we can skip this hook + if identity.SessionToken == nil { + return nil + } + + // if we recently have performed this it would be cached, so we can skip the hook + if _, ok := s.cache.Get(identity.ID); ok { + return nil + } + + token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id}) + // user is not authenticated through oauth so skip further checks + if !exists { + return nil + } + + idTokenExpiry, err := getIDTokenExpiry(token) + if err != nil { + s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err) + } + + // token has no expire time configured, so we don't have to refresh it + if token.OAuthExpiry.IsZero() { + // cache the token check, so we don't perform it on every request + s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry)) + return nil + } + + // get the token's auth provider (f.e. azuread) + provider := strings.TrimPrefix(token.AuthModule, "oauth_") + currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider) + if currentOAuthInfo == nil { + s.log.Warn("OAuth provider not found", "provider", provider) + return nil + } + + // if refresh token handling is disabled for this provider, we can skip the hook + if !currentOAuthInfo.UseRefreshToken { + return nil + } + + accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) + + hasIdTokenExpired := false + idTokenExpires := time.Time{} + + if !idTokenExpiry.IsZero() { + idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) + hasIdTokenExpired = idTokenExpires.Before(time.Now()) + } + // token has not expired, so we don't have to refresh it + if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired { + // cache the token check, so we don't perform it on every request + s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)) + return nil + } + // FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update + updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := s.service.TryTokenRefresh(updateCtx, token); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { + s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err) + } + + if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil { + s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err) + } + + if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil { + s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err) + } + + return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err) + } + + return nil +} + +const maxOAuthTokenCacheTTL = 10 * time.Minute + +func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration { + if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { + return maxOAuthTokenCacheTTL + } + + min := func(a, b time.Duration) time.Duration { + if a <= b { + return a + } + return b + } + + if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() { + return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL) + } + + if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { + return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL) + } + + return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL) +} + +// getIDTokenExpiry extracts the expiry time from the ID token +func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) { + if token.OAuthIdToken == "" { + return time.Time{}, nil + } + + parsedToken, err := jwt.ParseSigned(token.OAuthIdToken) + if err != nil { + return time.Time{}, fmt.Errorf("error parsing id token: %w", err) + } + + type Claims struct { + Exp int64 `json:"exp"` + } + var claims Claims + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err) + } + + return time.Unix(claims.Exp, 0), nil +} diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go new file mode 100644 index 00000000000..7d479e15ee4 --- /dev/null +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go @@ -0,0 +1,258 @@ +package sync + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/infra/localcache" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/socialtest" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" +) + +func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { + type testCase struct { + desc string + identity *authn.Identity + oauthInfo *social.OAuthInfo + + expectedHasEntryToken *login.UserAuth + expectHasEntryCalled bool + + expectedTryRefreshErr error + expectTryRefreshTokenCalled bool + + expectRevokeTokenCalled bool + expectInvalidateOauthTokensCalled bool + + expectedErr error + } + + tests := []testCase{ + { + desc: "should skip sync when identity is not a user", + identity: &authn.Identity{ID: "service-account:1"}, + }, + { + desc: "should skip sync when identity is a user but is not authenticated with session token", + identity: &authn.Identity{ID: "user:1"}, + }, + { + desc: "should skip sync when user has session but is not authenticated with oauth", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + }, + { + desc: "should skip sync for when access token don't have expire time", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectedHasEntryToken: &login.UserAuth{}, + }, + { + desc: "should skip sync when access token has no expired yet", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, + }, + { + desc: "should skip sync when access token has no expired yet", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, + }, + { + desc: "should refresh access token when is has expired", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectTryRefreshTokenCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, + }, + { + desc: "should invalidate access token and session token if access token can't be refreshed", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectedTryRefreshErr: errors.New("some err"), + expectTryRefreshTokenCalled: true, + expectInvalidateOauthTokensCalled: true, + expectRevokeTokenCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, + expectedErr: authn.ErrExpiredAccessToken, + }, { + desc: "should skip sync when use_refresh_token is disabled", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule}, + expectHasEntryCalled: true, + expectTryRefreshTokenCalled: false, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, + oauthInfo: &social.OAuthInfo{UseRefreshToken: false}, + }, + { + desc: "should refresh access token when ID token has expired", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectTryRefreshTokenCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))}, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var ( + hasEntryCalled bool + tryRefreshCalled bool + invalidateTokensCalled bool + revokeTokenCalled bool + ) + + service := &oauthtokentest.MockOauthTokenService{ + HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) { + hasEntryCalled = true + return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil + }, + InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error { + invalidateTokensCalled = true + return nil + }, + TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error { + tryRefreshCalled = true + return tt.expectedTryRefreshErr + }, + } + + sessionService := &authtest.FakeUserAuthTokenService{ + RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error { + revokeTokenCalled = true + return nil + }, + } + + if tt.oauthInfo == nil { + tt.oauthInfo = &social.OAuthInfo{ + UseRefreshToken: true, + } + } + + socialService := &socialtest.FakeSocialService{ + ExpectedAuthInfoProvider: tt.oauthInfo, + } + + sync := &OAuthTokenSync{ + log: log.NewNopLogger(), + cache: localcache.New(0, 0), + service: service, + sessionService: sessionService, + socialService: socialService, + } + + err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil) + assert.ErrorIs(t, err, tt.expectedErr) + assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled) + assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled) + assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled) + assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled) + }) + } +} + +// fakeIDToken is used to create a fake invalid token to verify expiry logic +func fakeIDToken(t *testing.T, expiryDate time.Time) string { + type Header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + } + type Payload struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Exp int64 `json:"exp"` + } + + header, err := json.Marshal(Header{Kid: "123", Alg: "none"}) + require.NoError(t, err) + u := expiryDate.UTC().Unix() + payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u}) + require.NoError(t, err) + + fakeSignature := []byte("6ICJm") + return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature)) +} + +func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) { + defaultTime := time.Now() + tests := []struct { + name string + accessTokenExpiry time.Time + idTokenExpiry time.Time + want time.Duration + }{ + { + name: "should return maxOAuthTokenCacheTTL when no expiry is given", + accessTokenExpiry: time.Time{}, + idTokenExpiry: time.Time{}, + + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl", + accessTokenExpiry: time.Time{}, + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl", + accessTokenExpiry: time.Time{}, + idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: time.Time{}, + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given", + accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: time.Time{}, + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry", + accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + want: maxOAuthTokenCacheTTL, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry) + + assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second)) + }) + } +} diff --git a/pkg/services/authn/authnimpl/sync/user_sync.go b/pkg/services/authn/authnimpl/sync/user_sync.go index dda54e8a2f0..0b0a5b09b91 100644 --- a/pkg/services/authn/authnimpl/sync/user_sync.go +++ b/pkg/services/authn/authnimpl/sync/user_sync.go @@ -394,5 +394,4 @@ func syncSignedInUserToIdentity(usr *user.SignedInUser, identity *authn.Identity identity.LastSeenAt = usr.LastSeenAt identity.IsDisabled = usr.IsDisabled identity.IsGrafanaAdmin = &usr.IsGrafanaAdmin - identity.AuthenticatedBy = usr.AuthenticatedBy } diff --git a/pkg/services/authn/clients/session.go b/pkg/services/authn/clients/session.go index 135a801c856..c97a87d27f1 100644 --- a/pkg/services/authn/clients/session.go +++ b/pkg/services/authn/clients/session.go @@ -3,23 +3,14 @@ package clients import ( "context" "errors" - "fmt" "net/url" - "strings" "time" - "github.com/go-jose/go-jose/v3/jwt" - - "github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/network" - "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/oauthtoken" - "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) @@ -27,31 +18,21 @@ import ( var _ authn.HookClient = new(Session) var _ authn.ContextAwareClient = new(Session) -func ProvideSession( - cfg *setting.Cfg, features *featuremgmt.FeatureManager, sessionService auth.UserTokenService, - oauthTokenService oauthtoken.OAuthTokenService, socialService social.Service, -) *Session { +func ProvideSession(cfg *setting.Cfg, sessionService auth.UserTokenService, + features *featuremgmt.FeatureManager) *Session { return &Session{ - cfg: cfg, - features: features, - sessionService: sessionService, - oauthTokenService: oauthTokenService, - socialService: socialService, - log: log.New(authn.ClientSession), - cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute), + cfg: cfg, + features: features, + sessionService: sessionService, + log: log.New(authn.ClientSession), } } type Session struct { - log log.Logger - cfg *setting.Cfg - features *featuremgmt.FeatureManager - - socialService social.Service - sessionService auth.UserTokenService - oauthTokenService oauthtoken.OAuthTokenService - - cache *localcache.CacheService + cfg *setting.Cfg + features *featuremgmt.FeatureManager + sessionService auth.UserTokenService + log log.Logger } func (s *Session) Name() string { @@ -107,19 +88,7 @@ func (s *Session) Priority() uint { } func (s *Session) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { - if identity.SessionToken == nil { - return nil - } - - if err := s.rotateTokenHook(ctx, identity, r); err != nil { - return err - } - - return s.syncOAuthTokenHook(ctx, identity, r) -} - -func (s *Session) rotateTokenHook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { - if s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) { + if identity.SessionToken == nil || s.features.IsEnabled(featuremgmt.FlagClientTokenRotation) { return nil } @@ -154,143 +123,3 @@ func (s *Session) rotateTokenHook(ctx context.Context, identity *authn.Identity, return nil } - -func (s *Session) syncOAuthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { - if !s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { - return nil - } - - namespace, id := identity.NamespacedID() - // only perform oauth token check if identity is a user - if namespace != authn.NamespaceUser { - return nil - } - - // if we recently have performed this it would be cached, so we can skip the hook - if _, ok := s.cache.Get(identity.ID); ok { - return nil - } - - token, exists, _ := s.oauthTokenService.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id}) - // user is not authenticated through oauth so skip further checks - if !exists { - // if user is not authenticated through oauth we can skip this check by adding the id to the cache - s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL) - return nil - } - - // get the token's auth provider (f.e. azuread) - provider := strings.TrimPrefix(token.AuthModule, "oauth_") - currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider) - if currentOAuthInfo == nil { - s.log.Warn("OAuth provider not found", "provider", provider) - return nil - } - - // if refresh token handling is disabled for this provider, we can skip the hook - if !currentOAuthInfo.UseRefreshToken { - // refresh token is not configured for provider so we can skip this check by adding the id to the cache - s.cache.Set(identity.ID, struct{}{}, maxOAuthTokenCacheTTL) - return nil - } - - idTokenExpiry, err := getIDTokenExpiry(token) - if err != nil { - s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err) - } - - // token has no expire time configured, so we don't have to refresh it - if token.OAuthExpiry.IsZero() { - // cache the token check, so we don't perform it on every request - s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry)) - return nil - } - - accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) - - hasIdTokenExpired := false - idTokenExpires := time.Time{} - - if !idTokenExpiry.IsZero() { - idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) - hasIdTokenExpired = idTokenExpires.Before(time.Now()) - } - - // token has not expired, so we don't have to refresh it - if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired { - // cache the token check, so we don't perform it on every request - s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)) - return nil - } - // FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update - updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - if err := s.oauthTokenService.TryTokenRefresh(updateCtx, token); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { - s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err) - } - - if err := s.oauthTokenService.InvalidateOAuthTokens(ctx, token); err != nil { - s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err) - } - - if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil { - s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err) - } - - return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err) - } - - return nil -} - -const maxOAuthTokenCacheTTL = 10 * time.Minute - -func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration { - if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { - return maxOAuthTokenCacheTTL - } - - min := func(a, b time.Duration) time.Duration { - if a <= b { - return a - } - return b - } - - if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() { - return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL) - } - - if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { - return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL) - } - - return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL) -} - -// getIDTokenExpiry extracts the expiry time from the ID token -func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) { - if token.OAuthIdToken == "" { - return time.Time{}, nil - } - - parsedToken, err := jwt.ParseSigned(token.OAuthIdToken) - if err != nil { - return time.Time{}, fmt.Errorf("error parsing id token: %w", err) - } - - type Claims struct { - Exp int64 `json:"exp"` - } - var claims Claims - if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { - return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err) - } - - return time.Unix(claims.Exp, 0), nil -} diff --git a/pkg/services/authn/clients/session_test.go b/pkg/services/authn/clients/session_test.go index f15fed8e749..8dd7555cd26 100644 --- a/pkg/services/authn/clients/session_test.go +++ b/pkg/services/authn/clients/session_test.go @@ -2,10 +2,6 @@ package clients import ( "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" "net" "net/http" "testing" @@ -14,16 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/login/socialtest" "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" - "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) @@ -38,7 +29,7 @@ func TestSession_Test(t *testing.T) { cfg := setting.NewCfg() cfg.LoginCookieName = "" cfg.LoginMaxLifetime = 20 * time.Second - s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{}, nil, nil) + s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{}, featuremgmt.WithFeatures()) disabled := s.Test(context.Background(), &authn.Request{HTTPRequest: validHTTPReq}) assert.False(t, disabled) @@ -154,7 +145,7 @@ func TestSession_Authenticate(t *testing.T) { cfg.LoginCookieName = cookieName cfg.TokenRotationIntervalMinutes = 10 cfg.LoginMaxLifetime = 20 * time.Second - s := ProvideSession(cfg, tt.fields.features, tt.fields.sessionService, nil, nil) + s := ProvideSession(cfg, tt.fields.sessionService, tt.fields.features) got, err := s.Authenticate(context.Background(), tt.args.r) require.True(t, (err != nil) == tt.wantErr, err) @@ -184,17 +175,17 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) { f.Status = statusCode } -func TestSession_RotateSessionHook(t *testing.T) { +func TestSession_Hook(t *testing.T) { t.Run("should rotate token", func(t *testing.T) { cfg := setting.NewCfg() cfg.LoginCookieName = "grafana-session" cfg.LoginMaxLifetime = 20 * time.Second - s := ProvideSession(cfg, featuremgmt.WithFeatures(), &authtest.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(_ context.Context, token *auth.UserToken, _ net.IP, _ string) (bool, *auth.UserToken, error) { + s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { token.UnhashedToken = "new-token" return true, token, nil }, - }, nil, nil) + }, featuremgmt.WithFeatures()) sampleID := &authn.Identity{ SessionToken: &auth.UserToken{ @@ -215,7 +206,7 @@ func TestSession_RotateSessionHook(t *testing.T) { Resp: web.NewResponseWriter(http.MethodConnect, mockResponseWriter), } - err := s.rotateTokenHook(context.Background(), sampleID, resp) + err := s.Hook(context.Background(), sampleID, resp) require.NoError(t, err) resp.Resp.WriteHeader(201) @@ -228,7 +219,7 @@ func TestSession_RotateSessionHook(t *testing.T) { }) t.Run("should not rotate token with feature flag", func(t *testing.T) { - s := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation), nil, nil, nil) + s := ProvideSession(setting.NewCfg(), nil, featuremgmt.WithFeatures(featuremgmt.FlagClientTokenRotation)) req := &authn.Request{} identity := &authn.Identity{} @@ -236,226 +227,3 @@ func TestSession_RotateSessionHook(t *testing.T) { require.NoError(t, err) }) } - -func TestSession_SyncOAuthTokenHook(t *testing.T) { - type testCase struct { - desc string - identity *authn.Identity - oauthInfo *social.OAuthInfo - - expectedHasEntryToken *login.UserAuth - expectHasEntryCalled bool - - expectedTryRefreshErr error - expectTryRefreshTokenCalled bool - - expectRevokeTokenCalled bool - expectInvalidateOauthTokensCalled bool - - expectedErr error - } - - tests := []testCase{ - { - desc: "should skip sync when identity is not a user", - identity: &authn.Identity{ID: "service-account:1"}, - }, - { - desc: "should skip sync when user has session but is not authenticated with oauth", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - }, - { - desc: "should skip sync for when access token don't have expire time", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectedHasEntryToken: &login.UserAuth{}, - }, - { - desc: "should skip sync when access token has no expired yet", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, - }, - { - desc: "should skip sync when access token has no expired yet", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, - }, - { - desc: "should refresh access token when is has expired", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectTryRefreshTokenCalled: true, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, - }, - { - desc: "should invalidate access token and session token if access token can't be refreshed", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectedTryRefreshErr: errors.New("some err"), - expectTryRefreshTokenCalled: true, - expectInvalidateOauthTokensCalled: true, - expectRevokeTokenCalled: true, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, - expectedErr: authn.ErrExpiredAccessToken, - }, { - desc: "should skip sync when use_refresh_token is disabled", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule}, - expectHasEntryCalled: true, - expectTryRefreshTokenCalled: false, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, - oauthInfo: &social.OAuthInfo{UseRefreshToken: false}, - }, - { - desc: "should refresh access token when ID token has expired", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, - expectHasEntryCalled: true, - expectTryRefreshTokenCalled: true, - expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))}, - }, - } - - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - var ( - hasEntryCalled bool - tryRefreshCalled bool - invalidateTokensCalled bool - revokeTokenCalled bool - ) - - oauthTokenService := &oauthtokentest.MockOauthTokenService{ - HasOAuthEntryFunc: func(_ context.Context, _ identity.Requester) (*login.UserAuth, bool, error) { - hasEntryCalled = true - return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil - }, - InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error { - invalidateTokensCalled = true - return nil - }, - TryTokenRefreshFunc: func(_ context.Context, _ *login.UserAuth) error { - tryRefreshCalled = true - return tt.expectedTryRefreshErr - }, - } - - sessionService := &authtest.FakeUserAuthTokenService{ - RevokeTokenProvider: func(_ context.Context, _ *auth.UserToken, _ bool) error { - revokeTokenCalled = true - return nil - }, - } - - if tt.oauthInfo == nil { - tt.oauthInfo = &social.OAuthInfo{ - UseRefreshToken: true, - } - } - - socialService := &socialtest.FakeSocialService{ - ExpectedAuthInfoProvider: tt.oauthInfo, - } - - client := ProvideSession(setting.NewCfg(), featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck), sessionService, oauthTokenService, socialService) - - err := client.syncOAuthTokenHook(context.Background(), tt.identity, nil) - assert.ErrorIs(t, err, tt.expectedErr) - assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled) - assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled) - assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled) - assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled) - }) - } -} - -// fakeIDToken is used to create sa fake invalid token to verify expiry logic -func fakeIDToken(t *testing.T, expiryDate time.Time) string { - type Header struct { - Kid string `json:"kid"` - Alg string `json:"alg"` - } - type Payload struct { - Iss string `json:"iss"` - Sub string `json:"sub"` - Exp int64 `json:"exp"` - } - - header, err := json.Marshal(Header{Kid: "123", Alg: "none"}) - require.NoError(t, err) - u := expiryDate.UTC().Unix() - payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u}) - require.NoError(t, err) - - fakeSignature := []byte("6ICJm") - return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature)) -} - -func TestGetOAuthTokenCacheTTL(t *testing.T) { - defaultTime := time.Now() - tests := []struct { - name string - accessTokenExpiry time.Time - idTokenExpiry time.Time - want time.Duration - }{ - { - name: "should return maxOAuthTokenCacheTTL when no expiry is given", - accessTokenExpiry: time.Time{}, - idTokenExpiry: time.Time{}, - - want: maxOAuthTokenCacheTTL, - }, - { - name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl", - accessTokenExpiry: time.Time{}, - idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - - want: maxOAuthTokenCacheTTL, - }, - { - name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl", - accessTokenExpiry: time.Time{}, - idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), - want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), - }, - { - name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given", - accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - idTokenExpiry: time.Time{}, - want: maxOAuthTokenCacheTTL, - }, - { - name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given", - accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), - idTokenExpiry: time.Time{}, - want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), - }, - { - name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry", - accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), - idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), - }, - { - name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry", - accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL), - want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)), - }, - { - name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl", - accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), - want: maxOAuthTokenCacheTTL, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry) - - assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second)) - }) - } -}