AuthN: Add post auth hook for oauth token refresh (#61608)

* AuthN: rename package to sync

* AuthN: rename sync files

* Ouath: Add mock for OauthTokenService

* AuthN: Implement access token refresh hook

* AuthN: remove feature check from hook

* AuthN: register post auth hook for oauth token refresh
This commit is contained in:
Karl Persson 2023-01-18 10:47:09 +01:00 committed by GitHub
parent 29119a7d08
commit 412d80b498
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 278 additions and 5 deletions

View File

@ -14,10 +14,12 @@ import (
"github.com/grafana/grafana/pkg/services/apikey"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
sync "github.com/grafana/grafana/pkg/services/authn/authnimpl/usersync"
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
"github.com/grafana/grafana/pkg/services/authn/clients"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/services/rendering"
@ -43,6 +45,7 @@ func ProvideService(
userProtectionService login.UserProtectionService,
loginAttempts loginattempt.Service, quotaService quota.Service,
authInfoService login.AuthInfoService, renderService rendering.Service,
features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService,
) *Service {
s := &Service{
log: log.New("authn.service"),
@ -111,6 +114,10 @@ func ProvideService(
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen)
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen)
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
s.RegisterPostAuthHook(sync.ProvideOauthTokenSync(oauthTokenService, sessionService).SyncOauthToken)
}
return s
}

View File

@ -0,0 +1,79 @@
package sync
import (
"context"
"errors"
"time"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/errutil"
)
var (
errExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token")
)
func ProvideOauthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService) *OauthTokenSync {
return &OauthTokenSync{
log.New("oauth_token.sync"),
service,
sessionService,
}
}
type OauthTokenSync struct {
log log.Logger
service oauthtoken.OAuthTokenService
sessionService auth.UserTokenService
}
func (s *OauthTokenSync) SyncOauthToken(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
namespace, id := identity.NamespacedID()
// only perform oauth token check if identity is a user
if namespace != authn.NamespaceUser {
return nil
}
// not authenticated through session tokens, so we can skip this hook
if identity.SessionToken == nil {
return nil
}
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
// user is not authenticated through oauth so skip further checks
if !exists {
return nil
}
// token has no expire time configured, so we don't have to refresh it
if token.OAuthExpiry.IsZero() {
return nil
}
// token has not expired, so we don't have to refresh it
if !token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(time.Now()) {
return nil
}
if err := s.service.TryTokenRefresh(ctx, token); err != nil {
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
s.log.FromContext(ctx).Error("could not refresh oauth access token for user", "userId", id, "err", err)
}
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
s.log.FromContext(ctx).Error("could not invalidate OAuth tokens", "userId", id, "err", err)
}
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
s.log.FromContext(ctx).Error("could not revoke token", "userId", id, "tokenId", identity.SessionToken.Id, "err", err)
}
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", auth.ErrInvalidSessionToken)
}
return nil
}

View File

@ -0,0 +1,134 @@
package sync
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
)
func TestOauthTokenSync_SyncOauthToken(t *testing.T) {
type testCase struct {
desc string
identity *authn.Identity
expectedHasEntryToken *models.UserAuth
expectHasEntryCalled bool
expectedTryRefreshErr error
expectTryRefreshTokenCalled bool
expectRevokeTokenCalled bool
expectInvalidateOauthTokensCalled bool
expectedErr error
}
tests := []testCase{
{
desc: "should skip sync when identity is not a user",
identity: &authn.Identity{ID: "service-account:1"},
},
{
desc: "should skip sync when identity is a user but is not authenticated with session token",
identity: &authn.Identity{ID: "user:1"},
},
{
desc: "should skip sync when user has session but is not authenticated with oauth",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
},
{
desc: "should skip sync for when access token don't have expire time",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &models.UserAuth{},
},
{
desc: "should skip sync when access token has no expired yet",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
{
desc: "should skip sync when access token has no expired yet",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
{
desc: "should refresh access token when is has expired",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
},
{
desc: "should invalidate access token and session token if access token can't be refreshed",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectedTryRefreshErr: errors.New("some err"),
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: true,
expectRevokeTokenCalled: true,
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: errExpiredAccessToken,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var (
hasEntryCalled bool
tryRefreshCalled bool
invalidateTokensCalled bool
revokeTokenCalled bool
)
service := &oauthtokentest.MockOauthTokenService{
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
hasEntryCalled = true
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
},
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *models.UserAuth) error {
invalidateTokensCalled = true
return nil
},
TryTokenRefreshFunc: func(ctx context.Context, usr *models.UserAuth) error {
tryRefreshCalled = true
return tt.expectedTryRefreshErr
},
}
sessionService := &authtest.FakeUserAuthTokenService{
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
revokeTokenCalled = true
return nil
},
}
sync := &OauthTokenSync{
log: log.NewNopLogger(),
service: service,
sessionService: sessionService,
}
err := sync.SyncOauthToken(context.Background(), tt.identity, nil)
assert.ErrorIs(t, err, tt.expectedErr)
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
})
}
}

View File

@ -1,4 +1,4 @@
package usersync
package sync
import (
"context"

View File

@ -1,4 +1,4 @@
package usersync
package sync
import (
"context"

View File

@ -1,4 +1,4 @@
package usersync
package sync
import (
"context"

View File

@ -1,4 +1,4 @@
package usersync
package sync
import (
"context"

View File

@ -0,0 +1,53 @@
package oauthtokentest
import (
"context"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/user"
"golang.org/x/oauth2"
)
type MockOauthTokenService struct {
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *models.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr *models.UserAuth) error
}
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
if m.GetCurrentOauthTokenFunc != nil {
return m.GetCurrentOauthTokenFunc(ctx, usr)
}
return nil
}
func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
if m.IsOAuthPassThruEnabledFunc != nil {
return m.IsOAuthPassThruEnabledFunc(ds)
}
return false
}
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
if m.HasOAuthEntryFunc != nil {
return m.HasOAuthEntryFunc(ctx, usr)
}
return nil, false, nil
}
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
if m.InvalidateOAuthTokensFunc != nil {
return m.InvalidateOAuthTokensFunc(ctx, usr)
}
return nil
}
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
if m.TryTokenRefreshFunc != nil {
return m.TryTokenRefreshFunc(ctx, usr)
}
return nil
}