OAuth: Introduce user_refresh_token setting and make it default for the selected providers (#71533)

* First changes

* WIP docs

* Align current tests

* Add test for UseRefreshToken

* Update docs

* Fix

* Remove unnecessary AuthCodeURL from generic_oauth

* Change GitHub to disable use_refresh_token by default
This commit is contained in:
Misi
2023-07-14 14:03:01 +02:00
committed by GitHub
parent 1f3aa099d5
commit dcf26564db
15 changed files with 333 additions and 96 deletions

View File

@@ -3,10 +3,12 @@ package sync
import (
"context"
"errors"
"strings"
"time"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/oauthtoken"
@@ -15,15 +17,19 @@ import (
)
var (
errExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token")
errExpiredAccessToken = errutil.NewBase(
errutil.StatusUnauthorized,
"oauth.expired-token",
errutil.WithPublicMessage("OAuth access token expired"))
)
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService) *OAuthTokenSync {
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
return &OAuthTokenSync{
log.New("oauth_token.sync"),
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
service,
sessionService,
socialService,
}
}
@@ -32,6 +38,7 @@ type OAuthTokenSync struct {
cache *localcache.CacheService
service oauthtoken.OAuthTokenService
sessionService auth.UserTokenService
socialService social.Service
}
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
@@ -64,6 +71,19 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
// get the token's auth provider (f.e. azuread)
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
s.log.Warn("OAuth provider not found", "provider", provider)
return nil
}
// if refresh token handling is disabled for this provider, we can skip the hook
if !currentOAuthInfo.UseRefreshToken {
return nil
}
expires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
// token has not expired, so we don't have to refresh it
if !expires.Before(time.Now()) {
@@ -78,14 +98,14 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
}
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
s.log.FromContext(ctx).Error("Failed invalidate OAuth tokens", "id", identity.ID, "error", err)
s.log.FromContext(ctx).Error("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
}
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
}
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", auth.ErrInvalidSessionToken)
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
}
return nil

View File

@@ -10,6 +10,8 @@ import (
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/socialtest"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/authn"
@@ -20,8 +22,9 @@ import (
func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
type testCase struct {
desc string
identity *authn.Identity
desc string
identity *authn.Identity
oauthInfo *social.OAuthInfo
expectedHasEntryToken *login.UserAuth
expectHasEntryCalled bool
@@ -84,6 +87,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectRevokeTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: errExpiredAccessToken,
}, {
desc: "should skip sync when use_refresh_token is disabled",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: false,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
},
}
@@ -118,11 +128,22 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
},
}
if tt.oauthInfo == nil {
tt.oauthInfo = &social.OAuthInfo{
UseRefreshToken: true,
}
}
socialService := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthInfo,
}
sync := &OAuthTokenSync{
log: log.NewNopLogger(),
cache: localcache.New(0, 0),
service: service,
sessionService: sessionService,
socialService: socialService,
}
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)