Auth: OAuth token sync improvements (#75943)

* Add metric, improve token refresh

* changes

* handle ctx cancelled

* Fix import order
This commit is contained in:
Misi 2023-10-05 11:19:43 +02:00 committed by GitHub
parent f08ad95c59
commit bd2191c158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 13 deletions

View File

@ -100,18 +100,24 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
return nil
}
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := s.service.TryTokenRefresh(ctx, token); err != nil {
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
s.log.FromContext(ctx).Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
}
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
s.log.FromContext(ctx).Error("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
}
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
}
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)

View File

@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
@ -33,6 +34,8 @@ type Service struct {
SocialService social.Service
AuthInfoService login.AuthInfoService
singleFlightGroup *singleflight.Group
tokenRefreshDuration *prometheus.HistogramVec
}
type OAuthTokenService interface {
@ -43,12 +46,13 @@ type OAuthTokenService interface {
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
}
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service {
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
return &Service{
Cfg: cfg,
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: new(singleflight.Group),
Cfg: cfg,
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: new(singleflight.Group),
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
}
}
@ -212,8 +216,12 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use
persistedToken := buildOAuthTokenFromAuthInfo(usr)
start := time.Now()
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
duration := time.Since(start)
o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
if err != nil {
logger.Error("Failed to retrieve oauth access token",
"provider", usr.AuthModule, "userId", usr.UserId, "error", err)
@ -254,6 +262,20 @@ func IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
}
func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus.HistogramVec {
tokenRefreshDuration := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "grafana",
Subsystem: "oauth",
Name: "token_refresh_fetch_duration_seconds",
Help: "Time taken to fetch access token using refresh token",
},
[]string{"auth_provider", "success"})
if registerer != nil {
registerer.MustRegister(tokenRefreshDuration)
}
return tokenRefreshDuration
}
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
func tokensEq(t1, t2 *oauth2.Token) bool {
return t1.AccessToken == t2.AccessToken &&

View File

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
@ -231,10 +232,11 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social
authInfoStore := &FakeAuthInfoStore{}
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
return &Service{
Cfg: setting.NewCfg(),
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: &singleflight.Group{},
Cfg: setting.NewCfg(),
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: &singleflight.Group{},
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
}, authInfoStore, socialConnector
}