grafana/pkg/services/authn/authnimpl/sync/oauth_token_sync.go
Misi bd2191c158
Auth: OAuth token sync improvements (#75943)
* Add metric, improve token refresh

* changes

* handle ctx cancelled

* Fix import order
2023-10-05 11:19:43 +02:00

175 lines
5.4 KiB
Go

package sync
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
)
// ProvideOAuthTokenSync wires up an OAuthTokenSync from its service
// dependencies.
//
// Besides storing the three services it creates a logger named
// "oauth_token.sync" and a local cache constructed with maxOAuthTokenCacheTTL
// and 15 minutes (presumably default expiration and cleanup interval — see
// localcache.New).
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
	return &OAuthTokenSync{
		log:            log.New("oauth_token.sync"),
		cache:          localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
		service:        service,
		sessionService: sessionService,
		socialService:  socialService,
	}
}
// OAuthTokenSync holds the dependencies of the SyncOauthTokenHook, which checks
// the stored OAuth tokens of a session-authenticated user and refreshes them
// once they have expired.
//
// NOTE: ProvideOAuthTokenSync fills this struct with a positional literal, so
// the field order below must stay in sync with that constructor.
type OAuthTokenSync struct {
// log is the logger used by the hook; created as "oauth_token.sync".
log log.Logger
// cache remembers, keyed by identity ID, that a token check was performed
// recently so the hook can return early on subsequent requests.
cache *localcache.CacheService
// service looks up, refreshes and invalidates stored OAuth tokens.
service oauthtoken.OAuthTokenService
// sessionService revokes the session token when a refresh fails.
sessionService auth.UserTokenService
// socialService resolves the OAuth provider settings (e.g. UseRefreshToken).
socialService social.Service
}
// SyncOauthTokenHook verifies that a user who signed in through OAuth still has
// valid tokens and tries to refresh them once the access token or the ID token
// has expired (taking oauthtoken.ExpiryDelta into account).
//
// The hook returns nil without doing anything when:
//   - the identity is not in the user namespace,
//   - the request was not authenticated with a session token,
//   - a recent check for this identity is still cached,
//   - the user has no OAuth entry (or it could not be looked up),
//   - the access token has no expiry,
//   - the OAuth provider is unknown or has refresh tokens disabled,
//   - neither token has expired yet.
//
// When a refresh is attempted and fails (for any reason other than
// context.Canceled), the stored OAuth tokens are invalidated, the session
// token is revoked and an authn.ErrExpiredAccessToken error wrapping the
// refresh error is returned.
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
	namespace, id := identity.NamespacedID()
	// only perform oauth token check if identity is a user
	if namespace != authn.NamespaceUser {
		return nil
	}

	// not authenticated through session tokens, so we can skip this hook
	if identity.SessionToken == nil {
		return nil
	}

	// if we recently have performed this it would be cached, so we can skip the hook
	if _, ok := s.cache.Get(identity.ID); ok {
		return nil
	}

	logger := s.log.FromContext(ctx)

	// cacheCheck remembers that the tokens were just checked. A non-positive
	// TTL is never stored: the go-cache implementation behind localcache
	// treats a zero duration as "use the default expiration" and a negative
	// one as "never expire", which would keep the hook from ever looking at
	// an already expired token again.
	cacheCheck := func(accessExpiry, idExpiry time.Time) {
		if ttl := getOAuthTokenCacheTTL(accessExpiry, idExpiry); ttl > 0 {
			s.cache.Set(identity.ID, struct{}{}, ttl)
		}
	}

	token, exists, err := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
	if err != nil {
		// The sync is best effort: a failed lookup must not block the request,
		// but it should not go unnoticed either.
		logger.Error("Failed to look up OAuth entry", "id", identity.ID, "error", err)
		return nil
	}

	// user is not authenticated through oauth so skip further checks
	if !exists || token == nil {
		return nil
	}

	// A malformed ID token is logged but not fatal; the zero time returned on
	// error means "no ID token expiry known" for the checks below.
	idTokenExpiry, err := getIDTokenExpiry(token)
	if err != nil {
		logger.Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
	}

	// token has no expire time configured, so we don't have to refresh it
	if token.OAuthExpiry.IsZero() {
		// cache the token check, so we don't perform it on every request
		cacheCheck(token.OAuthExpiry, idTokenExpiry)
		return nil
	}

	// get the token's auth provider (f.e. azuread)
	provider := strings.TrimPrefix(token.AuthModule, "oauth_")
	currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
	if currentOAuthInfo == nil {
		logger.Warn("OAuth provider not found", "provider", provider)
		return nil
	}

	// if refresh token handling is disabled for this provider, we can skip the hook
	if !currentOAuthInfo.UseRefreshToken {
		return nil
	}

	// Both expiries are moved forward by ExpiryDelta so that tokens are
	// refreshed shortly before they actually expire. Round(0) strips the
	// monotonic clock reading.
	now := time.Now()
	accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)

	hasIDTokenExpired := false
	idTokenExpires := time.Time{}
	if !idTokenExpiry.IsZero() {
		idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
		hasIDTokenExpired = idTokenExpires.Before(now)
	}

	// token has not expired, so we don't have to refresh it
	if !accessTokenExpires.Before(now) && !hasIDTokenExpired {
		// cache the token check, so we don't perform it on every request
		cacheCheck(accessTokenExpires, idTokenExpires)
		return nil
	}

	// The refresh runs on a context detached from the request so that a
	// client disconnect does not abort it halfway through.
	// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
	updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
		// A cancelled refresh says nothing about the validity of the tokens,
		// so the session is left untouched.
		if errors.Is(err, context.Canceled) {
			return nil
		}

		// A missing refresh token is an expected condition and not logged as an error.
		if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
			logger.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
		}

		if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
			logger.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
		}

		if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
			logger.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
		}

		return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
	}

	return nil
}
// maxOAuthTokenCacheTTL is the longest time a token check may stay cached.
const maxOAuthTokenCacheTTL = 10 * time.Minute

// getOAuthTokenCacheTTL returns how long a token check may be cached: the time
// left until the earlier of the two expiries, capped at maxOAuthTokenCacheTTL.
//
// A zero expiry means "no expiry known" and is ignored; if both are zero the
// result is maxOAuthTokenCacheTTL. The result is never negative: when an
// expiry already lies in the past, 0 is returned. A negative duration is not a
// meaningful TTL, and caches such as go-cache interpret it as "never expire".
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
	ttl := maxOAuthTokenCacheTTL

	for _, expiry := range [...]time.Time{accessTokenExpiry, idTokenExpiry} {
		if expiry.IsZero() {
			continue
		}
		if remaining := time.Until(expiry); remaining < ttl {
			ttl = remaining
		}
	}

	if ttl < 0 {
		return 0
	}
	return ttl
}
// getIDTokenExpiry reads the "exp" claim from the ID token stored on the given
// auth entry and returns it as a time.Time.
//
// An entry without an ID token yields the zero time and no error. The token's
// signature is NOT verified; only the claims are decoded. A token that cannot
// be parsed or whose claims cannot be decoded yields the zero time and an
// error.
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
	rawIDToken := token.OAuthIdToken
	if rawIDToken == "" {
		return time.Time{}, nil
	}

	jws, parseErr := jwt.ParseSigned(rawIDToken)
	if parseErr != nil {
		return time.Time{}, fmt.Errorf("error parsing id token: %w", parseErr)
	}

	// Only the expiry is of interest, so decode into a minimal claim set.
	var payload struct {
		Exp int64 `json:"exp"`
	}
	if claimsErr := jws.UnsafeClaimsWithoutVerification(&payload); claimsErr != nil {
		return time.Time{}, fmt.Errorf("error getting claims from id token: %w", claimsErr)
	}

	return time.Unix(payload.Exp, 0), nil
}