mirror of
https://github.com/grafana/grafana.git
synced 2024-11-25 18:30:41 -06:00
AuthN: Add post auth hook for oauth token refresh (#61608)
* AuthN: rename package to sync * AuthN: rename sync files * Ouath: Add mock for OauthTokenService * AuthN: Implement access token refresh hook * AuthN: remove feature check from hook * AuthN: register post auth hook for oauth token refresh
This commit is contained in:
parent
29119a7d08
commit
412d80b498
@ -14,10 +14,12 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
sync "github.com/grafana/grafana/pkg/services/authn/authnimpl/usersync"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
|
||||
"github.com/grafana/grafana/pkg/services/authn/clients"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/quota"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
@ -43,6 +45,7 @@ func ProvideService(
|
||||
userProtectionService login.UserProtectionService,
|
||||
loginAttempts loginattempt.Service, quotaService quota.Service,
|
||||
authInfoService login.AuthInfoService, renderService rendering.Service,
|
||||
features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService,
|
||||
) *Service {
|
||||
s := &Service{
|
||||
log: log.New("authn.service"),
|
||||
@ -111,6 +114,10 @@ func ProvideService(
|
||||
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen)
|
||||
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen)
|
||||
|
||||
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
||||
s.RegisterPostAuthHook(sync.ProvideOauthTokenSync(oauthTokenService, sessionService).SyncOauthToken)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
|
79
pkg/services/authn/authnimpl/sync/oauth_token_sync.go
Normal file
79
pkg/services/authn/authnimpl/sync/oauth_token_sync.go
Normal file
@ -0,0 +1,79 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token")
|
||||
)
|
||||
|
||||
func ProvideOauthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService) *OauthTokenSync {
|
||||
return &OauthTokenSync{
|
||||
log.New("oauth_token.sync"),
|
||||
service,
|
||||
sessionService,
|
||||
}
|
||||
}
|
||||
|
||||
type OauthTokenSync struct {
|
||||
log log.Logger
|
||||
service oauthtoken.OAuthTokenService
|
||||
sessionService auth.UserTokenService
|
||||
}
|
||||
|
||||
func (s *OauthTokenSync) SyncOauthToken(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||
namespace, id := identity.NamespacedID()
|
||||
// only perform oauth token check if identity is a user
|
||||
if namespace != authn.NamespaceUser {
|
||||
return nil
|
||||
}
|
||||
|
||||
// not authenticated through session tokens, so we can skip this hook
|
||||
if identity.SessionToken == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(time.Now()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.service.TryTokenRefresh(ctx, token); err != nil {
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
s.log.FromContext(ctx).Error("could not refresh oauth access token for user", "userId", id, "err", err)
|
||||
}
|
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||
s.log.FromContext(ctx).Error("could not invalidate OAuth tokens", "userId", id, "err", err)
|
||||
}
|
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||
s.log.FromContext(ctx).Error("could not revoke token", "userId", id, "tokenId", identity.SessionToken.Id, "err", err)
|
||||
}
|
||||
|
||||
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", auth.ErrInvalidSessionToken)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
134
pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go
Normal file
134
pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
func TestOauthTokenSync_SyncOauthToken(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
identity *authn.Identity
|
||||
|
||||
expectedHasEntryToken *models.UserAuth
|
||||
expectHasEntryCalled bool
|
||||
|
||||
expectedTryRefreshErr error
|
||||
expectTryRefreshTokenCalled bool
|
||||
|
||||
expectRevokeTokenCalled bool
|
||||
expectInvalidateOauthTokensCalled bool
|
||||
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should skip sync when identity is not a user",
|
||||
identity: &authn.Identity{ID: "service-account:1"},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should skip sync for when access token don't have expire time",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &models.UserAuth{},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should skip sync when access token has no expired yet",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when is has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectedTryRefreshErr: errors.New("some err"),
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectInvalidateOauthTokensCalled: true,
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
expectedErr: errExpiredAccessToken,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var (
|
||||
hasEntryCalled bool
|
||||
tryRefreshCalled bool
|
||||
invalidateTokensCalled bool
|
||||
revokeTokenCalled bool
|
||||
)
|
||||
|
||||
service := &oauthtokentest.MockOauthTokenService{
|
||||
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
hasEntryCalled = true
|
||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||
},
|
||||
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *models.UserAuth) error {
|
||||
invalidateTokensCalled = true
|
||||
return nil
|
||||
},
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr *models.UserAuth) error {
|
||||
tryRefreshCalled = true
|
||||
return tt.expectedTryRefreshErr
|
||||
},
|
||||
}
|
||||
|
||||
sessionService := &authtest.FakeUserAuthTokenService{
|
||||
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
|
||||
revokeTokenCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
sync := &OauthTokenSync{
|
||||
log: log.NewNopLogger(),
|
||||
service: service,
|
||||
sessionService: sessionService,
|
||||
}
|
||||
|
||||
err := sync.SyncOauthToken(context.Background(), tt.identity, nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
||||
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
||||
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
||||
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
||||
})
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
53
pkg/services/oauthtoken/oauthtokentest/mock.go
Normal file
53
pkg/services/oauthtoken/oauthtokentest/mock.go
Normal file
@ -0,0 +1,53 @@
|
||||
package oauthtokentest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type MockOauthTokenService struct {
|
||||
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token
|
||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error)
|
||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr *models.UserAuth) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr *models.UserAuth) error
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
|
||||
if m.GetCurrentOauthTokenFunc != nil {
|
||||
return m.GetCurrentOauthTokenFunc(ctx, usr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
if m.IsOAuthPassThruEnabledFunc != nil {
|
||||
return m.IsOAuthPassThruEnabledFunc(ds)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
if m.HasOAuthEntryFunc != nil {
|
||||
return m.HasOAuthEntryFunc(ctx, usr)
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
if m.InvalidateOAuthTokensFunc != nil {
|
||||
return m.InvalidateOAuthTokensFunc(ctx, usr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||
if m.TryTokenRefreshFunc != nil {
|
||||
return m.TryTokenRefreshFunc(ctx, usr)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user