mirror of
https://github.com/grafana/grafana.git
synced 2025-02-14 17:43:35 -06:00
Auth: OAuth token sync improvements (#75943)
* Add metric, improve token refresh * changes * handle ctx cancelled * Fix import order
This commit is contained in:
parent
f08ad95c59
commit
bd2191c158
@ -100,18 +100,24 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.service.TryTokenRefresh(ctx, token); err != nil {
|
||||
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
s.log.FromContext(ctx).Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
@ -33,6 +34,8 @@ type Service struct {
|
||||
SocialService social.Service
|
||||
AuthInfoService login.AuthInfoService
|
||||
singleFlightGroup *singleflight.Group
|
||||
|
||||
tokenRefreshDuration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
type OAuthTokenService interface {
|
||||
@ -43,12 +46,13 @@ type OAuthTokenService interface {
|
||||
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
||||
}
|
||||
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service {
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
|
||||
return &Service{
|
||||
Cfg: cfg,
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: new(singleflight.Group),
|
||||
Cfg: cfg,
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: new(singleflight.Group),
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
|
||||
}
|
||||
}
|
||||
|
||||
@ -212,8 +216,12 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use
|
||||
|
||||
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
||||
|
||||
start := time.Now()
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
duration := time.Since(start)
|
||||
o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to retrieve oauth access token",
|
||||
"provider", usr.AuthModule, "userId", usr.UserId, "error", err)
|
||||
@ -254,6 +262,20 @@ func IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
|
||||
}
|
||||
|
||||
func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus.HistogramVec {
|
||||
tokenRefreshDuration := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "grafana",
|
||||
Subsystem: "oauth",
|
||||
Name: "token_refresh_fetch_duration_seconds",
|
||||
Help: "Time taken to fetch access token using refresh token",
|
||||
},
|
||||
[]string{"auth_provider", "success"})
|
||||
if registerer != nil {
|
||||
registerer.MustRegister(tokenRefreshDuration)
|
||||
}
|
||||
return tokenRefreshDuration
|
||||
}
|
||||
|
||||
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
||||
func tokensEq(t1, t2 *oauth2.Token) bool {
|
||||
return t1.AccessToken == t2.AccessToken &&
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/oauth2"
|
||||
@ -231,10 +232,11 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social
|
||||
authInfoStore := &FakeAuthInfoStore{}
|
||||
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
|
||||
return &Service{
|
||||
Cfg: setting.NewCfg(),
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
Cfg: setting.NewCfg(),
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
|
||||
}, authInfoStore, socialConnector
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user