OAuth: Refactor OAuthToken service to make it easier to use the new external sessions (#96667)

* Refactor OAuthToken service

* introduce user.SessionAwareIdentityRequester

* replace login.UserAuth parameters with user.SessionAwareIdentityRequester

* Add nosec G101 to fake ID tokens

* Opt 2, min changes

* Revert a change to the current version
This commit is contained in:
Misi 2024-11-21 14:36:28 +01:00 committed by GitHub
parent afb4f6c0ce
commit 1061e4712f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 246 additions and 282 deletions

View File

@ -266,23 +266,21 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester) (*authn.Red
return nil, false return nil, false
} }
if err := c.oauthService.InvalidateOAuthTokens(ctx, &login.UserAuth{ ctxLogger := c.log.FromContext(ctx).New("userID", userID)
UserId: userID,
AuthId: user.GetAuthID(), if err := c.oauthService.InvalidateOAuthTokens(ctx, user); err != nil {
AuthModule: user.GetAuthenticatedBy(), ctxLogger.Error("Failed to invalidate tokens", "error", err)
}); err != nil {
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "id", user.GetID(), "error", err)
} }
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName) oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled { if !oauthCfg.Enabled {
c.log.FromContext(ctx).Debug("OAuth client is disabled") ctxLogger.Debug("OAuth client is disabled")
return nil, false return nil, false
} }
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg) redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
if redirectURL == "" { if redirectURL == "" {
c.log.FromContext(ctx).Debug("No signout redirect url configured") ctxLogger.Debug("No signout redirect url configured")
return nil, false return nil, false
} }

View File

@ -71,7 +71,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return error when state from ipd does not match stored state", desc: "should return error when state from ipd does not match stored state",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-other-state"), URL: mustParseURL("http://grafana.com/?state=some-other-state"),
}, },
@ -83,7 +84,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return error when pkce is configured but the cookie is not present", desc: "should return error when pkce is configured but the cookie is not present",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
@ -95,7 +97,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return error when email is empty", desc: "should return error when email is empty",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
@ -110,7 +113,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return error when email is not allowed", desc: "should return error when email is not allowed",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
@ -144,7 +148,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return identity for valid request", desc: "should return identity for valid request",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
@ -182,7 +187,8 @@ func TestOAuth_Authenticate(t *testing.T) {
}, },
{ {
desc: "should return identity for valid request - and lookup user by email", desc: "should return identity for valid request - and lookup user by email",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{
HTTPRequest: &http.Request{
Header: map[string][]string{}, Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
@ -354,9 +360,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
var ( authCodeUrlCalled := false
authCodeUrlCalled = false
)
fakeSocialSvc := &socialtest.FakeSocialService{ fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg, ExpectedAuthInfoProvider: tt.oauthCfg,
@ -475,7 +479,7 @@ func TestOAuth_Logout(t *testing.T) {
"id_token": "some.id.token", "id_token": "some.id.token",
}) })
}, },
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error { InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester) error {
invalidateTokenCalled = true invalidateTokenCalled = true
return nil return nil
}, },

View File

@ -51,11 +51,12 @@ type OAuthTokenService interface {
IsOAuthPassThruEnabled(*datasources.DataSource) bool IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error) TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error)
InvalidateOAuthTokens(context.Context, *login.UserAuth) error InvalidateOAuthTokens(context.Context, identity.Requester) error
} }
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer, func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
serverLockService *serverlock.ServerLockService, tracer tracing.Tracer) *Service { serverLockService *serverlock.ServerLockService, tracer tracing.Tracer,
) *Service {
return &Service{ return &Service{
AuthInfoService: authInfoService, AuthInfoService: authInfoService,
Cfg: cfg, Cfg: cfg,
@ -71,6 +72,27 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken") ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken")
defer span.End() defer span.End()
ctxLogger := logger.FromContext(ctx)
if usr == nil || usr.IsNil() {
ctxLogger.Warn("Can only get OAuth tokens for existing users", "user", "nil")
// Not user, no token.
return nil
}
if !usr.IsIdentityType(claims.TypeUser) {
ctxLogger.Warn("Can only get OAuth tokens for users", "id", usr.GetID())
return nil
}
userID, err := usr.GetInternalID()
if err != nil {
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
return nil
}
ctxLogger = ctxLogger.New("userID", userID)
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr) authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
if !ok { if !ok {
return nil return nil
@ -84,7 +106,9 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
return nil return nil
} }
persistedToken, refreshNeeded := needTokenRefresh(authInfo) persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
refreshNeeded := needTokenRefresh(ctx, persistedToken)
if !refreshNeeded { if !refreshNeeded {
return persistedToken return persistedToken
} }
@ -226,14 +250,16 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) (
return return
} }
storedToken, needRefresh := needTokenRefresh(authInfo) persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
needRefresh := needTokenRefresh(ctx, persistedToken)
if !needRefresh { if !needRefresh {
// Set the token which is returned by the outer function in case there's no need to refresh the token // Set the token which is returned by the outer function in case there's no need to refresh the token
newToken = storedToken newToken = persistedToken
return return
} }
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, authInfo) newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr)
}, retryOpt) }, retryOpt)
if lockErr != nil { if lockErr != nil {
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr) ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
@ -280,11 +306,17 @@ func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
} }
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero // InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.UserAuth) error { func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error {
userID, err := usr.GetInternalID()
if err != nil {
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
return err
}
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{ return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
UserId: authInfo.UserId, UserId: userID,
AuthModule: authInfo.AuthModule, AuthModule: usr.GetAuthenticatedBy(),
AuthId: authInfo.AuthId, AuthId: usr.GetAuthID(),
OAuthToken: &oauth2.Token{ OAuthToken: &oauth2.Token{
AccessToken: "", AccessToken: "",
RefreshToken: "", RefreshToken: "",
@ -293,23 +325,31 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.Use
}) })
} }
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login.UserAuth) (*oauth2.Token, error) { func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester) (*oauth2.Token, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken", ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken")
trace.WithAttributes(attribute.Int64("userID", authInfo.UserId)))
defer span.End() defer span.End()
ctxLogger := logger.FromContext(ctx).New("userID", authInfo.UserId) userID, err := usr.GetInternalID()
if err != nil {
if err := checkOAuthRefreshToken(authInfo); err != nil { logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
return nil, err return nil, err
} }
persistedToken, refreshNeeded := needTokenRefresh(authInfo) span.SetAttributes(attribute.Int64("userID", userID))
ctxLogger := logger.FromContext(ctx).New("userID", userID)
if persistedToken.RefreshToken == "" {
ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy())
return nil, ErrNoRefreshTokenFound
}
refreshNeeded := needTokenRefresh(ctx, persistedToken)
if !refreshNeeded { if !refreshNeeded {
return persistedToken, nil return persistedToken, nil
} }
authProvider := authInfo.AuthModule authProvider := usr.GetAuthenticatedBy()
connect, err := o.SocialService.GetConnector(authProvider) connect, err := o.SocialService.GetConnector(authProvider)
if err != nil { if err != nil {
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err) ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
@ -331,11 +371,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
if err != nil { if err != nil {
ctxLogger.Error("Failed to retrieve oauth access token", ctxLogger.Error("Failed to retrieve oauth access token",
"provider", authInfo.AuthModule, "userId", authInfo.UserId, "error", err) "provider", usr.GetAuthenticatedBy(), "error", err)
// token refresh failed, invalidate the old token // token refresh failed, invalidate the old token
if err := o.InvalidateOAuthTokens(ctx, authInfo); err != nil { if err := o.InvalidateOAuthTokens(ctx, usr); err != nil {
ctxLogger.Warn("Failed to invalidate OAuth tokens", "id", authInfo.Id, "error", err) ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err)
} }
return nil, err return nil, err
@ -344,15 +384,15 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
// If the tokens are not the same, update the entry in the DB // If the tokens are not the same, update the entry in the DB
if !tokensEq(persistedToken, token) { if !tokensEq(persistedToken, token) {
updateAuthCommand := &login.UpdateAuthInfoCommand{ updateAuthCommand := &login.UpdateAuthInfoCommand{
UserId: authInfo.UserId, UserId: userID,
AuthModule: authInfo.AuthModule, AuthModule: usr.GetAuthenticatedBy(),
AuthId: authInfo.AuthId, AuthId: usr.GetAuthID(),
OAuthToken: token, OAuthToken: token,
} }
if o.Cfg.Env == setting.Dev { if o.Cfg.Env == setting.Dev {
ctxLogger.Debug("Oauth got token", ctxLogger.Debug("Oauth got token",
"auth_module", authInfo.AuthModule, "auth_module", usr.GetAuthID(),
"expiry", fmt.Sprintf("%v", token.Expiry), "expiry", fmt.Sprintf("%v", token.Expiry),
"access_token", fmt.Sprintf("%v", token.AccessToken), "access_token", fmt.Sprintf("%v", token.AccessToken),
"refresh_token", fmt.Sprintf("%v", token.RefreshToken), "refresh_token", fmt.Sprintf("%v", token.RefreshToken),
@ -360,7 +400,7 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
} }
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
ctxLogger.Error("Failed to update auth info during token refresh", "userId", authInfo.UserId, "error", err) ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err)
return token, err return token, err
} }
ctxLogger.Debug("Updated oauth info for user") ctxLogger.Debug("Updated oauth info for user")
@ -401,14 +441,14 @@ func tokensEq(t1, t2 *oauth2.Token) bool {
t1IdToken == t2IdToken t1IdToken == t2IdToken
} }
func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) { func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool {
var hasAccessTokenExpired, hasIdTokenExpired bool var hasAccessTokenExpired, hasIdTokenExpired bool
persistedToken := buildOAuthTokenFromAuthInfo(authInfo) ctxLogger := logger.FromContext(ctx)
idTokenExp, err := GetIDTokenExpiry(persistedToken) idTokenExp, err := GetIDTokenExpiry(persistedToken)
if err != nil { if err != nil {
logger.Warn("Could not get ID Token expiry", "error", err) ctxLogger.Warn("Could not get ID Token expiry", "error", err)
} }
if !persistedToken.Expiry.IsZero() { if !persistedToken.Expiry.IsZero() {
_, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry) _, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
@ -417,14 +457,14 @@ func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) {
_, hasIdTokenExpired = getExpiryWithSkew(idTokenExp) _, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
} }
if !hasAccessTokenExpired && !hasIdTokenExpired { if !hasAccessTokenExpired && !hasIdTokenExpired {
logger.Debug("Neither access nor id token have expired yet", "userID", authInfo.UserId) ctxLogger.Debug("Neither access nor id token have expired yet")
return persistedToken, false return false
} }
if hasIdTokenExpired { if hasIdTokenExpired {
// Force refreshing token when id token is expired // Force refreshing token when id token is expired
persistedToken.AccessToken = "" persistedToken.AccessToken = ""
} }
return persistedToken, true return true
} }
// GetIDTokenExpiry extracts the expiry time from the ID token // GetIDTokenExpiry extracts the expiry time from the ID token

View File

@ -31,7 +31,9 @@ import (
"github.com/grafana/grafana/pkg/tests/testsuite" "github.com/grafana/grafana/pkg/tests/testsuite"
) )
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U" const EXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjoxNjAwMDAwMDAwLCJpYXQiOjE2MDAwMDAwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
const UNEXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjo0ODg1NjA4MDAwLCJpYXQiOjE2ODU2MDgwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
testsuite.Run(m) testsuite.Run(m)
@ -162,19 +164,44 @@ func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.Delet
} }
func TestService_TryTokenRefresh(t *testing.T) { func TestService_TryTokenRefresh(t *testing.T) {
unexpiredToken := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{
"id_token": UNEXPIRED_ID_TOKEN,
})
expiredToken := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
type environment struct { type environment struct {
authInfoService *authinfotest.FakeService authInfoService *authinfotest.FakeService
serverLock *serverlock.ServerLockService serverLock *serverlock.ServerLockService
identity identity.Requester
socialConnector *socialtest.MockSocialConnector socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService socialService *socialtest.FakeSocialService
service *Service service *Service
} }
type testCase struct { type testCase struct {
desc string desc string
expectedErr error identity identity.Requester
setup func(env *environment) setup func(env *environment)
expectedToken *oauth2.Token
expectedErr error
}
userIdentity := &authn.Identity{
AuthenticatedBy: login.GenericOAuthModule,
ID: "1234",
Type: claims.TypeUser,
} }
tests := []testCase{ tests := []testCase{
@ -183,83 +210,43 @@ func TestService_TryTokenRefresh(t *testing.T) {
}, },
{ {
desc: "should skip sync when identity is not a user", desc: "should skip sync when identity is not a user",
setup: func(env *environment) { identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount},
env.identity = &authn.Identity{ID: "1", Type: claims.TypeServiceAccount}
},
}, },
{ {
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID", desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
setup: func(env *environment) { identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser},
env.identity = &authn.Identity{ID: "invalid", Type: claims.TypeUser}
},
},
{
desc: "should skip token refresh since the token is still valid",
setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
env.identity = &authn.Identity{
AuthenticatedBy: login.GenericOAuthModule,
ID: "1234",
Type: claims.TypeUser,
}
},
}, },
{ {
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned", desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedError = errors.New("some error") env.authInfoService.ExpectedError = errors.New("some error")
}, },
}, },
{ {
desc: "should skip token refresh if the user doesn't have an oauth entry", desc: "should skip token refresh if the user doesn't have an oauth entry",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{ env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.SAMLAuthModule, AuthModule: login.SAMLAuthModule,
} }
}, },
}, },
{
desc: "should do token refresh if access token or id token have not expired yet",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
}
},
},
{ {
desc: "should skip token refresh when no oauth provider was found", desc: "should skip token refresh when no oauth provider was found",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{ env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule, AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT,
} }
}, },
}, },
{ {
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{ env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule, AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT,
} }
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: false, UseRefreshToken: false,
@ -267,29 +254,66 @@ func TestService_TryTokenRefresh(t *testing.T) {
}, },
}, },
{ {
desc: "should skip token refresh when there is no refresh token", desc: "should skip token refresh when the token is still valid and no id token is present",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.authInfoService.ExpectedUserAuth = &login.UserAuth{ env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule, AuthModule: login.GenericOAuthModule,
OAuthIdToken: EXPIRED_JWT, OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
},
expectedToken: unexpiredToken,
},
{
desc: "should not refresh the tokens if access token or id token have not expired yet",
identity: userIdentity,
setup: func(env *environment) {
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthIdToken: UNEXPIRED_ID_TOKEN,
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
},
expectedToken: unexpiredTokenWithIDToken,
},
{
desc: "should skip token refresh when there is no refresh token",
identity: userIdentity,
setup: func(env *environment) {
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
OAuthRefreshToken: "", OAuthRefreshToken: "",
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
} }
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true, UseRefreshToken: true,
} }
}, },
expectedToken: &oauth2.Token{
AccessToken: unexpiredTokenWithIDToken.AccessToken,
RefreshToken: "",
Expiry: unexpiredTokenWithIDToken.Expiry,
},
}, },
{ {
desc: "should do token refresh when the token is expired", desc: "should do token refresh when the token is expired",
identity: userIdentity,
setup: func(env *environment) { setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true, UseRefreshToken: true,
} }
@ -297,24 +321,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
AuthModule: login.GenericOAuthModule, AuthModule: login.GenericOAuthModule,
AuthId: "subject", AuthId: "subject",
UserId: 1, UserId: 1,
OAuthAccessToken: token.AccessToken, OAuthAccessToken: expiredToken.AccessToken,
OAuthRefreshToken: token.RefreshToken, OAuthRefreshToken: expiredToken.RefreshToken,
OAuthExpiry: token.Expiry, OAuthExpiry: expiredToken.Expiry,
OAuthTokenType: token.TokenType, OAuthTokenType: expiredToken.TokenType,
OAuthIdToken: EXPIRED_ID_TOKEN,
} }
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
}, },
expectedToken: unexpiredTokenWithIDToken,
}, },
{ {
desc: "should refresh token when the id token is expired", desc: "should refresh token when the id token is expired",
identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule},
setup: func(env *environment) { setup: func(env *environment) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true, UseRefreshToken: true,
} }
@ -322,19 +342,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
AuthModule: login.GenericOAuthModule, AuthModule: login.GenericOAuthModule,
AuthId: "subject", AuthId: "subject",
UserId: 1, UserId: 1,
OAuthAccessToken: token.AccessToken, OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
OAuthRefreshToken: token.RefreshToken, OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
OAuthExpiry: token.Expiry, OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
OAuthTokenType: token.TokenType, OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
OAuthIdToken: EXPIRED_JWT, OAuthIdToken: EXPIRED_ID_TOKEN,
} }
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
}, },
expectedToken: unexpiredTokenWithIDToken,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{} socialConnector := socialtest.NewMockSocialConnector(t)
store := db.InitTestDB(t) store := db.InitTestDB(t)
@ -361,11 +382,27 @@ func TestService_TryTokenRefresh(t *testing.T) {
) )
// token refresh // token refresh
_, err := env.service.TryTokenRefresh(context.Background(), env.identity) actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity)
// test and validations if tt.expectedErr != nil {
assert.ErrorIs(t, err, tt.expectedErr) assert.ErrorIs(t, err, tt.expectedErr)
socialConnector.AssertExpectations(t) return
}
if tt.expectedToken == nil {
assert.Nil(t, actualToken)
return
}
assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken)
assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken)
assert.Equal(t, tt.expectedToken.Expiry, actualToken.Expiry)
assert.Equal(t, tt.expectedToken.TokenType, actualToken.TokenType)
if tt.expectedToken.Extra("id_token") != nil {
assert.Equal(t, tt.expectedToken.Extra("id_token").(string), actualToken.Extra("id_token").(string))
} else {
assert.Nil(t, actualToken.Extra("id_token"))
}
}) })
} }
} }
@ -392,7 +429,7 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
{ {
name: "should flag token refresh with id token is expired", name: "should flag token refresh with id token is expired",
usr: &login.UserAuth{ usr: &login.UserAuth{
OAuthIdToken: EXPIRED_JWT, OAuthIdToken: EXPIRED_ID_TOKEN,
}, },
expectedTokenRefreshFlag: true, expectedTokenRefreshFlag: true,
expectedTokenDuration: time.Second, expectedTokenDuration: time.Second,
@ -408,125 +445,10 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
token, needsTokenRefresh := needTokenRefresh(tt.usr) token := buildOAuthTokenFromAuthInfo(tt.usr)
needsTokenRefresh := needTokenRefresh(context.Background(), token)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh) assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
}) })
} }
} }
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
timeNow := time.Now()
token := &oauth2.Token{
AccessToken: "oauth_access_token",
RefreshToken: "refresh_token_found",
Expiry: timeNow,
TokenType: "Bearer",
}
type environment struct {
authInfoService *authinfotest.FakeService
serverLock *serverlock.ServerLockService
socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService
service *Service
}
tests := []struct {
desc string
expectedErr error
expectedToken *oauth2.Token
usr *login.UserAuth
setup func(env *environment)
}{
{
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.SAMLAuthModule,
},
expectedErr: ErrNotAnOAuthProvider,
},
{
desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
},
expectedErr: ErrNoRefreshTokenFound,
},
{
desc: "should not refresh token if the token is not expired",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: timeNow.Add(time.Hour),
OAuthTokenType: "Bearer",
},
expectedToken: &oauth2.Token{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expiry: timeNow.Add(time.Hour),
TokenType: "Bearer",
},
},
{
desc: "should update saved token if the user auth has new access/refresh tokens",
usr: &login.UserAuth{
UserId: int64(1234),
AuthModule: login.GenericOAuthModule,
OAuthAccessToken: "new_oauth_access_token",
OAuthRefreshToken: "new_refresh_token_found",
OAuthExpiry: timeNow,
},
expectedToken: &oauth2.Token{
AccessToken: "oauth_access_token",
RefreshToken: "refresh_token_found",
Expiry: timeNow,
TokenType: "Bearer",
},
setup: func(env *environment) {
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{}
store := db.InitTestDB(t)
env := environment{
authInfoService: &authinfotest.FakeService{},
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
socialConnector: socialConnector,
socialService: &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
},
}
if tt.setup != nil {
tt.setup(&env)
}
env.service = ProvideService(
env.socialService,
env.authInfoService,
setting.NewCfg(),
prometheus.NewRegistry(),
env.serverLock,
tracing.InitializeTracerForTest(),
)
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)
if tt.expectedToken != nil {
assert.Equal(t, tt.expectedToken, token)
}
assert.ErrorIs(t, tt.expectedErr, err)
socialConnector.AssertExpectations(t)
})
}
}

View File

@ -14,7 +14,7 @@ type MockOauthTokenService struct {
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error) TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error)
} }
@ -39,7 +39,7 @@ func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity.
return nil, false, nil return nil, false, nil
} }
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth) error { func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error {
if m.InvalidateOAuthTokensFunc != nil { if m.InvalidateOAuthTokensFunc != nil {
return m.InvalidateOAuthTokensFunc(ctx, usr) return m.InvalidateOAuthTokensFunc(ctx, usr)
} }

View File

@ -37,6 +37,6 @@ func (s *Service) TryTokenRefresh(context.Context, identity.Requester) (*oauth2.
return s.Token, nil return s.Token, nil
} }
func (s *Service) InvalidateOAuthTokens(context.Context, *login.UserAuth) error { func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester) error {
return nil return nil
} }