mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Handle when access token has already been refreshed in OAuth token sync (#77118)
* Use singleflight to prevent logging error if the token has already been refreshed * Change order of error checks * align tests, change error name * Change sf key * Update based on the review * refactor
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||
@@ -26,6 +26,7 @@ func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService
|
||||
service,
|
||||
sessionService,
|
||||
socialService,
|
||||
new(singleflight.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,10 +36,11 @@ type OAuthTokenSync struct {
|
||||
service oauthtoken.OAuthTokenService
|
||||
sessionService auth.UserTokenService
|
||||
socialService social.Service
|
||||
sf *singleflight.Group
|
||||
}
|
||||
|
||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||
namespace, id := identity.NamespacedID()
|
||||
namespace, _ := identity.NamespacedID()
|
||||
// only perform oauth token check if identity is a user
|
||||
if namespace != authn.NamespaceUser {
|
||||
return nil
|
||||
@@ -54,7 +56,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||
token, exists, _ := s.service.HasOAuthEntry(ctx, identity)
|
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists {
|
||||
return nil
|
||||
@@ -85,42 +87,62 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
accessTokenExpires, hasAccessTokenExpired := getExpiryWithSkew(token.OAuthExpiry)
|
||||
|
||||
hasIdTokenExpired := false
|
||||
idTokenExpires := time.Time{}
|
||||
|
||||
if !idTokenExpiry.IsZero() {
|
||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExpiry)
|
||||
}
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||
}
|
||||
_, err, _ = s.sf.Do(identity.ID, func() (interface{}, error) {
|
||||
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
|
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
}
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
if refreshErr := s.service.TryTokenRefresh(updateCtx, token); refreshErr != nil {
|
||||
if errors.Is(refreshErr, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
token, _, err := s.service.HasOAuthEntry(ctx, identity)
|
||||
if err != nil {
|
||||
s.log.Error("Failed to get OAuth entry for verifying if token has already been refreshed", "id", identity.ID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if the access token has already been refreshed by another request (for example in HA scenario)
|
||||
tokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
if !tokenExpires.Before(time.Now()) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", refreshErr)
|
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
return nil, refreshErr
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return authn.ErrExpiredAccessToken.Errorf("OAuth access token could not be refreshed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -172,3 +194,9 @@ func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
|
||||
adjustedExpiry = expiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
@@ -75,7 +76,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when is has expired",
|
||||
desc: "should refresh access token when it has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
@@ -155,6 +156,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
service: service,
|
||||
sessionService: sessionService,
|
||||
socialService: socialService,
|
||||
sf: new(singleflight.Group),
|
||||
}
|
||||
|
||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
||||
|
||||
Reference in New Issue
Block a user