grafana/pkg/services/oauthtoken/oauth_token_test.go
Gabriel MABILLE 596e828150
Fix: Refresh token when id_token is expired (#79569)
* Fix: Refresh token when id_token is expired

* add id_token comparison

* Fix wire

* Use userID as cache key

* Apply suggestions from code review

---------

Co-authored-by: linoman <2051016+linoman@users.noreply.github.com>
Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>
2024-02-05 16:44:25 +01:00

611 lines
20 KiB
Go

package oauthtoken
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
// EXPIRED_JWT is a JWT fixture with an HS256 header whose payload decodes to
// {"sub":"1234567890"}; it carries no "exp" claim. The tests in this file use
// it as an ID token for which a refresh is expected to be flagged.
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
// TestService_HasOAuthEntry runs HasOAuthEntry against a fake auth info store
// and checks the reported existence flag for every scenario, plus the error
// and the returned entry for the scenarios that specify them.
func TestService_HasOAuthEntry(t *testing.T) {
	type scenario struct {
		name            string
		user            *user.SignedInUser
		want            *login.UserAuth
		wantExist       bool
		wantErr         bool
		err             error
		getAuthInfoErr  error
		getAuthInfoUser login.UserAuth
	}

	scenarios := []scenario{
		{
			name:      "returns false without an error in case user is nil",
			user:      nil,
			want:      nil,
			wantExist: false,
			wantErr:   false,
		},
		{
			name:           "returns false and an error in case GetAuthInfo returns an error",
			user:           &user.SignedInUser{UserID: 1},
			want:           nil,
			wantExist:      false,
			wantErr:        true,
			getAuthInfoErr: errors.New("error"),
		},
		{
			name:           "returns false without an error in case auth entry is not found",
			user:           &user.SignedInUser{UserID: 1},
			want:           nil,
			wantExist:      false,
			wantErr:        false,
			getAuthInfoErr: user.ErrUserNotFound,
		},
		{
			name:            "returns false without an error in case the auth entry is not oauth",
			user:            &user.SignedInUser{UserID: 1},
			want:            nil,
			wantExist:       false,
			wantErr:         false,
			getAuthInfoUser: login.UserAuth{AuthModule: "auth_saml"},
		},
		{
			name:            "returns true when the auth entry is found",
			user:            &user.SignedInUser{UserID: 1},
			want:            &login.UserAuth{AuthModule: login.GenericOAuthModule},
			wantExist:       true,
			wantErr:         false,
			getAuthInfoUser: login.UserAuth{AuthModule: login.GenericOAuthModule},
		},
	}

	for i := range scenarios {
		// Work on a per-iteration copy so the pointer handed to the fake store
		// below never aliases another scenario.
		sc := scenarios[i]

		t.Run(sc.name, func(t *testing.T) {
			svc, store, _ := setupOAuthTokenService(t)
			store.ExpectedError = sc.getAuthInfoErr
			store.ExpectedOAuth = &sc.getAuthInfoUser

			got, found, err := svc.HasOAuthEntry(context.Background(), sc.user)

			// The error is only checked when one is expected.
			if sc.wantErr {
				assert.Error(t, err)
			}

			// The returned entry is only compared when the scenario provides one.
			if sc.want != nil {
				assert.True(t, reflect.DeepEqual(sc.want, got))
			}

			assert.Equal(t, sc.wantExist, found)
		})
	}
}
// setupOAuthTokenService assembles a Service for tests and returns it together
// with the fake auth info store and the mock social connector it was wired
// with, so callers can steer their behaviour. The fake social service reports
// an OAuth provider with UseRefreshToken enabled.
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
	t.Helper()

	connector := new(socialtest.MockSocialConnector)
	store := &FakeAuthInfoStore{ExpectedOAuth: new(login.UserAuth)}

	// The auth info service is the real implementation, backed by the fake
	// store, a fake remote cache and a test secrets service.
	cacheStorage := remotecache.NewFakeCacheStorage()
	secretsSvc := secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore())
	authInfo := authinfoimpl.ProvideService(store, cacheStorage, secretsSvc)

	providerInfo := &social.OAuthInfo{UseRefreshToken: true}
	socialSvc := &socialtest.FakeSocialService{
		ExpectedConnector:        connector,
		ExpectedAuthInfoProvider: providerInfo,
	}

	svc := &Service{
		Cfg:                  setting.NewCfg(),
		SocialService:        socialSvc,
		AuthInfoService:      authInfo,
		singleFlightGroup:    new(singleflight.Group),
		tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
		cache:                localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
	}

	return svc, store, connector
}
// FakeAuthInfoStore is a test double for login.Store. It embeds the interface
// so it satisfies it without implementing every method; only the methods
// defined below are overridden. The setup code in this file never assigns the
// embedded Store, so calling a method that is not overridden would hit a nil
// interface.
type FakeAuthInfoStore struct {
	login.Store
	// ExpectedError is returned by every overridden method.
	ExpectedError error
	// ExpectedOAuth is returned by GetAuthInfo and updated by UpdateAuthInfo.
	ExpectedOAuth *login.UserAuth
}

// GetAuthInfo returns the configured ExpectedOAuth and ExpectedError; the
// query is ignored.
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
	return f.ExpectedOAuth, f.ExpectedError
}

// SetAuthInfo stores nothing and returns ExpectedError.
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
	return f.ExpectedError
}

// UpdateAuthInfo copies the access token, expiry, token type and refresh token
// from cmd.OAuthToken onto ExpectedOAuth, then returns ExpectedError.
//
// The copy is skipped when cmd, cmd.OAuthToken or ExpectedOAuth is nil, so a
// test that updates auth info without a token (or without a seeded entry)
// gets ExpectedError back instead of a nil-pointer panic inside the fake.
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
	if cmd != nil && cmd.OAuthToken != nil && f.ExpectedOAuth != nil {
		f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken
		f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry
		f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType
		f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken
	}
	return f.ExpectedError
}

// DeleteAuthInfo deletes nothing and returns ExpectedError.
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
	return f.ExpectedError
}
// TestService_TryTokenRefresh drives Service.TryTokenRefresh through a table
// of scenarios. Each scenario gets a fresh environment (fake auth info
// service, local cache, mock social connector, fake social service) that its
// setup function may customise before the Service under test is built.
//
// Every scenario asserts that the returned error matches expectedErr (nil for
// all current cases) and that the expectations registered on the mock
// connector were met. Only the scenarios that are meant to refresh register a
// TokenSource expectation.
func TestService_TryTokenRefresh(t *testing.T) {
	type environment struct {
		authInfoService *authinfotest.FakeService
		cache           *localcache.CacheService
		identity        identity.Requester
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		service *Service
	}

	type testCase struct {
		desc        string
		expectedErr error
		setup       func(env *environment)
	}

	tests := []testCase{
		{
			desc: "should skip sync when identity is nil",
		},
		{
			desc: "should skip sync when identity is not a user",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "service-account:1"}
			},
		},
		{
			desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:invalidIdentifierFormat"}
			},
		},
		{
			desc: "should skip token refresh since the token is still valid",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(time.Hour),
					TokenType:    "Bearer",
				}

				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
				}

				env.identity = &authn.Identity{
					AuthenticatedBy: login.GenericOAuthModule,
					ID:              "user:1234",
				}
			},
		},
		{
			desc: "should skip token refresh if the expiration check has already been cached",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
			},
		},
		{
			desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedError = errors.New("some error")
			},
		},
		{
			desc: "should skip token refresh if the user doesn't have an oauth entry",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.SAMLAuthModule,
				}
			},
		},
		{
			// The auth entry carries neither an expiry nor an ID token and no
			// TokenSource call is registered on the connector, so this case
			// covers the "nothing to refresh" path.
			desc: "should skip token refresh if access token or id token have not expired yet",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.GenericOAuthModule,
				}
			},
		},
		{
			// ExpectedAuthInfoProvider is left unset on the fake social service.
			desc: "should skip token refresh when no oauth provider was found",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:   login.GenericOAuthModule,
					OAuthIdToken: EXPIRED_JWT,
				}
			},
		},
		{
			desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:   login.GenericOAuthModule,
					OAuthIdToken: EXPIRED_JWT,
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: false,
				}
			},
		},
		{
			desc: "should skip token refresh when there is no refresh token",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthIdToken:      EXPIRED_JWT,
					OAuthRefreshToken: "",
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
		},
		{
			desc: "should do token refresh when the token is expired",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(-time.Hour),
					TokenType:    "Bearer",
				}
				env.identity = &authn.Identity{ID: "user:1234"}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
				}
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
		{
			// The access token is still valid for an hour; the refresh is
			// expected because of the ID token alone.
			desc: "should refresh token when the id token is expired",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(time.Hour),
					TokenType:    "Bearer",
				}
				env.identity = &authn.Identity{ID: "user:1234"}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
					OAuthIdToken:      EXPIRED_JWT,
				}
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := &socialtest.MockSocialConnector{}

			env := environment{
				authInfoService: &authinfotest.FakeService{},
				cache:           localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			// The service is built after setup so it picks up any fakes the
			// scenario replaced or configured.
			env.service = &Service{
				AuthInfoService:      env.authInfoService,
				Cfg:                  setting.NewCfg(),
				cache:                env.cache,
				singleFlightGroup:    &singleflight.Group{},
				SocialService:        env.socialService,
				tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
			}

			// token refresh
			err := env.service.TryTokenRefresh(context.Background(), env.identity)

			// test and validations
			assert.ErrorIs(t, err, tt.expectedErr)
			socialConnector.AssertExpectations(t)
		})
	}
}
// TestOAuthTokenSync_getOAuthTokenCacheTTL checks getOAuthTokenCacheTTL for
// combinations of access-token and ID-token expiries that are unset, beyond
// maxOAuthTokenCacheTTL from now, or within it. Results are compared after
// rounding to the second because the expected durations are derived from the
// wall clock.
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
	defaultTime := time.Now()

	// Reference instants shared by the scenarios below.
	var (
		unset         time.Time
		fiveOverMax   = defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL)
		fiveUnderMax  = defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)
		threeUnderMax = defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)
	)

	type scenario struct {
		name              string
		accessTokenExpiry time.Time
		idTokenExpiry     time.Time
		want              time.Duration
	}

	scenarios := []scenario{
		{
			name:              "should return maxOAuthTokenCacheTTL when no expiry is given",
			accessTokenExpiry: unset,
			idTokenExpiry:     unset,
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
			accessTokenExpiry: unset,
			idTokenExpiry:     fiveOverMax,
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
			accessTokenExpiry: unset,
			idTokenExpiry:     fiveUnderMax,
			want:              time.Until(fiveUnderMax),
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
			accessTokenExpiry: fiveOverMax,
			idTokenExpiry:     unset,
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
			accessTokenExpiry: fiveUnderMax,
			idTokenExpiry:     unset,
			want:              time.Until(fiveUnderMax),
		},
		{
			name:              "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
			accessTokenExpiry: fiveUnderMax,
			idTokenExpiry:     fiveOverMax,
			want:              time.Until(fiveUnderMax),
		},
		{
			name:              "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
			accessTokenExpiry: fiveOverMax,
			idTokenExpiry:     threeUnderMax,
			want:              time.Until(threeUnderMax),
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
			accessTokenExpiry: fiveOverMax,
			idTokenExpiry:     fiveOverMax,
			want:              maxOAuthTokenCacheTTL,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			ttl := getOAuthTokenCacheTTL(sc.accessTokenExpiry, sc.idTokenExpiry)

			wantRounded := sc.want.Round(time.Second)
			gotRounded := ttl.Round(time.Second)
			assert.Equal(t, wantRounded, gotRounded)
		})
	}
}
// TestOAuthTokenSync_needTokenRefresh feeds needTokenRefresh a handful of
// UserAuth entries and checks three things for each: that a token is
// returned, whether a refresh is flagged, and the duration reported alongside.
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
	type scenario struct {
		name                     string
		usr                      *login.UserAuth
		expectedTokenRefreshFlag bool
		expectedTokenDuration    time.Duration
	}

	scenarios := []scenario{
		{
			name:                     "should not need token refresh when token has no expiration date",
			usr:                      new(login.UserAuth),
			expectedTokenRefreshFlag: false,
			expectedTokenDuration:    maxOAuthTokenCacheTTL,
		},
		{
			name:                     "should not need token refresh with an invalid jwt token that might result in an error when parsing",
			usr:                      &login.UserAuth{OAuthIdToken: "invalid_jwt_format"},
			expectedTokenRefreshFlag: false,
			expectedTokenDuration:    maxOAuthTokenCacheTTL,
		},
		{
			name:                     "should flag token refresh with id token is expired",
			usr:                      &login.UserAuth{OAuthIdToken: EXPIRED_JWT},
			expectedTokenRefreshFlag: true,
			expectedTokenDuration:    time.Second,
		},
		{
			// time.Unix(0, 0) is the Unix epoch, not the zero time.Time.
			name:                     "should flag token refresh when expiry date is zero",
			usr:                      &login.UserAuth{OAuthExpiry: time.Unix(0, 0)},
			expectedTokenRefreshFlag: true,
			expectedTokenDuration:    time.Second,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			tok, refresh, ttl := needTokenRefresh(sc.usr)

			assert.NotNil(t, tok)
			assert.Equal(t, sc.expectedTokenRefreshFlag, refresh)
			assert.Equal(t, sc.expectedTokenDuration, ttl)
		})
	}
}
// TestOAuthTokenSync_tryGetOrRefreshOAuthToken exercises
// Service.tryGetOrRefreshOAuthToken with a table of UserAuth entries. Each
// scenario gets a fresh environment that its optional setup function may
// customise (pre-seeding the cache, registering connector expectations)
// before the Service is built.
//
// A scenario may pin the returned token (expectedToken) and the returned
// error (expectedErr); a nil expectedErr means no error is expected. The
// expectations registered on the mock connector are verified for every
// scenario.
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
	timeNow := time.Now()
	// token is the value handed out by the mocked token source and seeded
	// into the cache by the scenarios that need it.
	token := &oauth2.Token{
		AccessToken:  "oauth_access_token",
		RefreshToken: "refresh_token_found",
		Expiry:       timeNow,
		TokenType:    "Bearer",
	}

	type environment struct {
		authInfoService *authinfotest.FakeService
		cache           *localcache.CacheService
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		service *Service
	}

	tests := []struct {
		desc          string
		expectedErr   error
		expectedToken *oauth2.Token
		usr           *login.UserAuth
		setup         func(env *environment)
	}{
		{
			desc: "should find and retrieve token from cache",
			usr: &login.UserAuth{
				UserId:           int64(1234),
				OAuthAccessToken: "new_access_token",
				OAuthExpiry:      timeNow,
			},
			setup: func(env *environment) {
				env.cache.Set("token-check-1234", token, 1*time.Minute)
			},
			expectedToken: &oauth2.Token{
				AccessToken: "new_access_token",
				Expiry:      timeNow,
			},
		},
		{
			desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
			usr: &login.UserAuth{
				UserId:     int64(1234),
				AuthModule: login.SAMLAuthModule,
			},
			expectedErr: ErrNotAnOAuthProvider,
		},
		{
			desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
			usr: &login.UserAuth{
				UserId:     int64(1234),
				AuthModule: login.GenericOAuthModule,
			},
			expectedErr: ErrNoRefreshTokenFound,
		},
		{
			desc: "should not refresh token if the token is not expired",
			usr: &login.UserAuth{
				UserId:            int64(1234),
				AuthModule:        login.GenericOAuthModule,
				OAuthAccessToken:  token.AccessToken,
				OAuthRefreshToken: token.RefreshToken,
				OAuthExpiry:       timeNow.Add(time.Hour),
				OAuthTokenType:    "Bearer",
			},
			expectedToken: &oauth2.Token{
				AccessToken:  token.AccessToken,
				RefreshToken: token.RefreshToken,
				Expiry:       timeNow.Add(time.Hour),
				TokenType:    "Bearer",
			},
		},
		{
			desc: "should update saved token if the user auth has new access/refresh tokens",
			usr: &login.UserAuth{
				UserId:            int64(1234),
				AuthModule:        login.GenericOAuthModule,
				OAuthAccessToken:  "new_oauth_access_token",
				OAuthRefreshToken: "new_refresh_token_found",
				OAuthExpiry:       timeNow,
			},
			expectedToken: &oauth2.Token{
				AccessToken:  "oauth_access_token",
				RefreshToken: "refresh_token_found",
				Expiry:       timeNow,
				TokenType:    "Bearer",
			},
			setup: func(env *environment) {
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := &socialtest.MockSocialConnector{}

			env := environment{
				authInfoService: &authinfotest.FakeService{},
				cache:           localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			env.service = &Service{
				AuthInfoService:      env.authInfoService,
				Cfg:                  setting.NewCfg(),
				cache:                env.cache,
				singleFlightGroup:    &singleflight.Group{},
				SocialService:        env.socialService,
				tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
			}

			// Named got so it does not shadow the shared token fixture above.
			got, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)

			if tt.expectedToken != nil {
				assert.Equal(t, tt.expectedToken, got)
			}

			// ErrorIs takes the actual error first and the target second; in
			// that order a returned error that wraps the sentinel still matches.
			assert.ErrorIs(t, err, tt.expectedErr)
			socialConnector.AssertExpectations(t)
		})
	}
}