mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
OAuth: Refactor OAuthToken service to make it easier to use the new external sessions (#96667)
* Refactor OAuthToken service * introduce user.SessionAwareIdentityRequester * replace login.UserAuth parameters with user.SessionAwareIdentityRequester * Add nosec G101 to fake ID tokens * Opt 2, min changes * Revert a change to the current version
This commit is contained in:
parent
afb4f6c0ce
commit
1061e4712f
@ -266,23 +266,21 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester) (*authn.Red
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.oauthService.InvalidateOAuthTokens(ctx, &login.UserAuth{
|
ctxLogger := c.log.FromContext(ctx).New("userID", userID)
|
||||||
UserId: userID,
|
|
||||||
AuthId: user.GetAuthID(),
|
if err := c.oauthService.InvalidateOAuthTokens(ctx, user); err != nil {
|
||||||
AuthModule: user.GetAuthenticatedBy(),
|
ctxLogger.Error("Failed to invalidate tokens", "error", err)
|
||||||
}); err != nil {
|
|
||||||
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "id", user.GetID(), "error", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
|
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
|
||||||
if !oauthCfg.Enabled {
|
if !oauthCfg.Enabled {
|
||||||
c.log.FromContext(ctx).Debug("OAuth client is disabled")
|
ctxLogger.Debug("OAuth client is disabled")
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
|
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
|
||||||
if redirectURL == "" {
|
if redirectURL == "" {
|
||||||
c.log.FromContext(ctx).Debug("No signout redirect url configured")
|
ctxLogger.Debug("No signout redirect url configured")
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,10 +71,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when state from ipd does not match stored state",
|
desc: "should return error when state from ipd does not match stored state",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
@ -83,10 +84,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when pkce is configured but the cookie is not present",
|
desc: "should return error when pkce is configured but the cookie is not present",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
@ -95,10 +97,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when email is empty",
|
desc: "should return error when email is empty",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
@ -110,10 +113,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when email is not allowed",
|
desc: "should return error when email is not allowed",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
@ -144,10 +148,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return identity for valid request",
|
desc: "should return identity for valid request",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
@ -182,10 +187,11 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return identity for valid request - and lookup user by email",
|
desc: "should return identity for valid request - and lookup user by email",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{
|
||||||
Header: map[string][]string{},
|
HTTPRequest: &http.Request{
|
||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
Header: map[string][]string{},
|
||||||
},
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
allowInsecureTakeover: true,
|
allowInsecureTakeover: true,
|
||||||
@ -354,9 +360,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.desc, func(t *testing.T) {
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
var (
|
authCodeUrlCalled := false
|
||||||
authCodeUrlCalled = false
|
|
||||||
)
|
|
||||||
|
|
||||||
fakeSocialSvc := &socialtest.FakeSocialService{
|
fakeSocialSvc := &socialtest.FakeSocialService{
|
||||||
ExpectedAuthInfoProvider: tt.oauthCfg,
|
ExpectedAuthInfoProvider: tt.oauthCfg,
|
||||||
@ -475,7 +479,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
"id_token": "some.id.token",
|
"id_token": "some.id.token",
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
|
InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester) error {
|
||||||
invalidateTokenCalled = true
|
invalidateTokenCalled = true
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -51,11 +51,12 @@ type OAuthTokenService interface {
|
|||||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||||
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
||||||
TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error)
|
TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error)
|
||||||
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
InvalidateOAuthTokens(context.Context, identity.Requester) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
|
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
|
||||||
serverLockService *serverlock.ServerLockService, tracer tracing.Tracer) *Service {
|
serverLockService *serverlock.ServerLockService, tracer tracing.Tracer,
|
||||||
|
) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
AuthInfoService: authInfoService,
|
AuthInfoService: authInfoService,
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
@ -71,6 +72,27 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
|
|||||||
ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken")
|
ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
ctxLogger := logger.FromContext(ctx)
|
||||||
|
|
||||||
|
if usr == nil || usr.IsNil() {
|
||||||
|
ctxLogger.Warn("Can only get OAuth tokens for existing users", "user", "nil")
|
||||||
|
// Not user, no token.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !usr.IsIdentityType(claims.TypeUser) {
|
||||||
|
ctxLogger.Warn("Can only get OAuth tokens for users", "id", usr.GetID())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := usr.GetInternalID()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxLogger = ctxLogger.New("userID", userID)
|
||||||
|
|
||||||
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
|
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@ -84,7 +106,9 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
persistedToken, refreshNeeded := needTokenRefresh(authInfo)
|
persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
|
||||||
|
|
||||||
|
refreshNeeded := needTokenRefresh(ctx, persistedToken)
|
||||||
if !refreshNeeded {
|
if !refreshNeeded {
|
||||||
return persistedToken
|
return persistedToken
|
||||||
}
|
}
|
||||||
@ -226,14 +250,16 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) (
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
storedToken, needRefresh := needTokenRefresh(authInfo)
|
persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
|
||||||
|
|
||||||
|
needRefresh := needTokenRefresh(ctx, persistedToken)
|
||||||
if !needRefresh {
|
if !needRefresh {
|
||||||
// Set the token which is returned by the outer function in case there's no need to refresh the token
|
// Set the token which is returned by the outer function in case there's no need to refresh the token
|
||||||
newToken = storedToken
|
newToken = persistedToken
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, authInfo)
|
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr)
|
||||||
}, retryOpt)
|
}, retryOpt)
|
||||||
if lockErr != nil {
|
if lockErr != nil {
|
||||||
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
|
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
|
||||||
@ -280,11 +306,17 @@ func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
||||||
func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.UserAuth) error {
|
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error {
|
||||||
|
userID, err := usr.GetInternalID()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
|
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
|
||||||
UserId: authInfo.UserId,
|
UserId: userID,
|
||||||
AuthModule: authInfo.AuthModule,
|
AuthModule: usr.GetAuthenticatedBy(),
|
||||||
AuthId: authInfo.AuthId,
|
AuthId: usr.GetAuthID(),
|
||||||
OAuthToken: &oauth2.Token{
|
OAuthToken: &oauth2.Token{
|
||||||
AccessToken: "",
|
AccessToken: "",
|
||||||
RefreshToken: "",
|
RefreshToken: "",
|
||||||
@ -293,23 +325,31 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.Use
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login.UserAuth) (*oauth2.Token, error) {
|
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester) (*oauth2.Token, error) {
|
||||||
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken",
|
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken")
|
||||||
trace.WithAttributes(attribute.Int64("userID", authInfo.UserId)))
|
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
ctxLogger := logger.FromContext(ctx).New("userID", authInfo.UserId)
|
userID, err := usr.GetInternalID()
|
||||||
|
if err != nil {
|
||||||
if err := checkOAuthRefreshToken(authInfo); err != nil {
|
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
persistedToken, refreshNeeded := needTokenRefresh(authInfo)
|
span.SetAttributes(attribute.Int64("userID", userID))
|
||||||
|
|
||||||
|
ctxLogger := logger.FromContext(ctx).New("userID", userID)
|
||||||
|
|
||||||
|
if persistedToken.RefreshToken == "" {
|
||||||
|
ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy())
|
||||||
|
return nil, ErrNoRefreshTokenFound
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshNeeded := needTokenRefresh(ctx, persistedToken)
|
||||||
if !refreshNeeded {
|
if !refreshNeeded {
|
||||||
return persistedToken, nil
|
return persistedToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
authProvider := authInfo.AuthModule
|
authProvider := usr.GetAuthenticatedBy()
|
||||||
connect, err := o.SocialService.GetConnector(authProvider)
|
connect, err := o.SocialService.GetConnector(authProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
|
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
|
||||||
@ -331,11 +371,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctxLogger.Error("Failed to retrieve oauth access token",
|
ctxLogger.Error("Failed to retrieve oauth access token",
|
||||||
"provider", authInfo.AuthModule, "userId", authInfo.UserId, "error", err)
|
"provider", usr.GetAuthenticatedBy(), "error", err)
|
||||||
|
|
||||||
// token refresh failed, invalidate the old token
|
// token refresh failed, invalidate the old token
|
||||||
if err := o.InvalidateOAuthTokens(ctx, authInfo); err != nil {
|
if err := o.InvalidateOAuthTokens(ctx, usr); err != nil {
|
||||||
ctxLogger.Warn("Failed to invalidate OAuth tokens", "id", authInfo.Id, "error", err)
|
ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -344,15 +384,15 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
|
|||||||
// If the tokens are not the same, update the entry in the DB
|
// If the tokens are not the same, update the entry in the DB
|
||||||
if !tokensEq(persistedToken, token) {
|
if !tokensEq(persistedToken, token) {
|
||||||
updateAuthCommand := &login.UpdateAuthInfoCommand{
|
updateAuthCommand := &login.UpdateAuthInfoCommand{
|
||||||
UserId: authInfo.UserId,
|
UserId: userID,
|
||||||
AuthModule: authInfo.AuthModule,
|
AuthModule: usr.GetAuthenticatedBy(),
|
||||||
AuthId: authInfo.AuthId,
|
AuthId: usr.GetAuthID(),
|
||||||
OAuthToken: token,
|
OAuthToken: token,
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.Cfg.Env == setting.Dev {
|
if o.Cfg.Env == setting.Dev {
|
||||||
ctxLogger.Debug("Oauth got token",
|
ctxLogger.Debug("Oauth got token",
|
||||||
"auth_module", authInfo.AuthModule,
|
"auth_module", usr.GetAuthID(),
|
||||||
"expiry", fmt.Sprintf("%v", token.Expiry),
|
"expiry", fmt.Sprintf("%v", token.Expiry),
|
||||||
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
||||||
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
||||||
@ -360,7 +400,7 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
||||||
ctxLogger.Error("Failed to update auth info during token refresh", "userId", authInfo.UserId, "error", err)
|
ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err)
|
||||||
return token, err
|
return token, err
|
||||||
}
|
}
|
||||||
ctxLogger.Debug("Updated oauth info for user")
|
ctxLogger.Debug("Updated oauth info for user")
|
||||||
@ -401,14 +441,14 @@ func tokensEq(t1, t2 *oauth2.Token) bool {
|
|||||||
t1IdToken == t2IdToken
|
t1IdToken == t2IdToken
|
||||||
}
|
}
|
||||||
|
|
||||||
func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) {
|
func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool {
|
||||||
var hasAccessTokenExpired, hasIdTokenExpired bool
|
var hasAccessTokenExpired, hasIdTokenExpired bool
|
||||||
|
|
||||||
persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
|
ctxLogger := logger.FromContext(ctx)
|
||||||
|
|
||||||
idTokenExp, err := GetIDTokenExpiry(persistedToken)
|
idTokenExp, err := GetIDTokenExpiry(persistedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Could not get ID Token expiry", "error", err)
|
ctxLogger.Warn("Could not get ID Token expiry", "error", err)
|
||||||
}
|
}
|
||||||
if !persistedToken.Expiry.IsZero() {
|
if !persistedToken.Expiry.IsZero() {
|
||||||
_, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
|
_, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
|
||||||
@ -417,14 +457,14 @@ func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) {
|
|||||||
_, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
|
_, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
|
||||||
}
|
}
|
||||||
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
if !hasAccessTokenExpired && !hasIdTokenExpired {
|
||||||
logger.Debug("Neither access nor id token have expired yet", "userID", authInfo.UserId)
|
ctxLogger.Debug("Neither access nor id token have expired yet")
|
||||||
return persistedToken, false
|
return false
|
||||||
}
|
}
|
||||||
if hasIdTokenExpired {
|
if hasIdTokenExpired {
|
||||||
// Force refreshing token when id token is expired
|
// Force refreshing token when id token is expired
|
||||||
persistedToken.AccessToken = ""
|
persistedToken.AccessToken = ""
|
||||||
}
|
}
|
||||||
return persistedToken, true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIDTokenExpiry extracts the expiry time from the ID token
|
// GetIDTokenExpiry extracts the expiry time from the ID token
|
||||||
|
@ -31,7 +31,9 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/tests/testsuite"
|
"github.com/grafana/grafana/pkg/tests/testsuite"
|
||||||
)
|
)
|
||||||
|
|
||||||
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
|
const EXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjoxNjAwMDAwMDAwLCJpYXQiOjE2MDAwMDAwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
|
||||||
|
|
||||||
|
const UNEXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjo0ODg1NjA4MDAwLCJpYXQiOjE2ODU2MDgwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
testsuite.Run(m)
|
testsuite.Run(m)
|
||||||
@ -162,19 +164,44 @@ func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.Delet
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestService_TryTokenRefresh(t *testing.T) {
|
func TestService_TryTokenRefresh(t *testing.T) {
|
||||||
|
unexpiredToken := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{
|
||||||
|
"id_token": UNEXPIRED_ID_TOKEN,
|
||||||
|
})
|
||||||
|
|
||||||
|
expiredToken := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
|
||||||
type environment struct {
|
type environment struct {
|
||||||
authInfoService *authinfotest.FakeService
|
authInfoService *authinfotest.FakeService
|
||||||
serverLock *serverlock.ServerLockService
|
serverLock *serverlock.ServerLockService
|
||||||
identity identity.Requester
|
|
||||||
socialConnector *socialtest.MockSocialConnector
|
socialConnector *socialtest.MockSocialConnector
|
||||||
socialService *socialtest.FakeSocialService
|
socialService *socialtest.FakeSocialService
|
||||||
|
|
||||||
service *Service
|
service *Service
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
desc string
|
desc string
|
||||||
expectedErr error
|
identity identity.Requester
|
||||||
setup func(env *environment)
|
setup func(env *environment)
|
||||||
|
expectedToken *oauth2.Token
|
||||||
|
expectedErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
userIdentity := &authn.Identity{
|
||||||
|
AuthenticatedBy: login.GenericOAuthModule,
|
||||||
|
ID: "1234",
|
||||||
|
Type: claims.TypeUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
@ -182,84 +209,44 @@ func TestService_TryTokenRefresh(t *testing.T) {
|
|||||||
desc: "should skip sync when identity is nil",
|
desc: "should skip sync when identity is nil",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip sync when identity is not a user",
|
desc: "should skip sync when identity is not a user",
|
||||||
setup: func(env *environment) {
|
identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount},
|
||||||
env.identity = &authn.Identity{ID: "1", Type: claims.TypeServiceAccount}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
|
desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
|
||||||
setup: func(env *environment) {
|
identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser},
|
||||||
env.identity = &authn.Identity{ID: "invalid", Type: claims.TypeUser}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip token refresh since the token is still valid",
|
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "testrefresh",
|
|
||||||
Expiry: time.Now().Add(time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
|
||||||
AuthModule: login.GenericOAuthModule,
|
|
||||||
OAuthAccessToken: token.AccessToken,
|
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
|
||||||
OAuthExpiry: token.Expiry,
|
|
||||||
OAuthTokenType: token.TokenType,
|
|
||||||
}
|
|
||||||
|
|
||||||
env.identity = &authn.Identity{
|
|
||||||
AuthenticatedBy: login.GenericOAuthModule,
|
|
||||||
ID: "1234",
|
|
||||||
Type: claims.TypeUser,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
|
|
||||||
setup: func(env *environment) {
|
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedError = errors.New("some error")
|
env.authInfoService.ExpectedError = errors.New("some error")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip token refresh if the user doesn't have an oauth entry",
|
desc: "should skip token refresh if the user doesn't have an oauth entry",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
AuthModule: login.SAMLAuthModule,
|
AuthModule: login.SAMLAuthModule,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should do token refresh if access token or id token have not expired yet",
|
desc: "should skip token refresh when no oauth provider was found",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
AuthModule: login.GenericOAuthModule,
|
AuthModule: login.GenericOAuthModule,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip token refresh when no oauth provider was found",
|
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
AuthModule: login.GenericOAuthModule,
|
AuthModule: login.GenericOAuthModule,
|
||||||
OAuthIdToken: EXPIRED_JWT,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
|
|
||||||
setup: func(env *environment) {
|
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
|
||||||
AuthModule: login.GenericOAuthModule,
|
|
||||||
OAuthIdToken: EXPIRED_JWT,
|
|
||||||
}
|
}
|
||||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
UseRefreshToken: false,
|
UseRefreshToken: false,
|
||||||
@ -267,29 +254,66 @@ func TestService_TryTokenRefresh(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should skip token refresh when there is no refresh token",
|
desc: "should skip token refresh when the token is still valid and no id token is present",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
|
|
||||||
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
AuthModule: login.GenericOAuthModule,
|
AuthModule: login.GenericOAuthModule,
|
||||||
OAuthIdToken: EXPIRED_JWT,
|
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
|
||||||
|
OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
|
||||||
|
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
|
||||||
|
OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedToken: unexpiredToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should not refresh the tokens if access token or id token have not expired yet",
|
||||||
|
identity: userIdentity,
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthIdToken: UNEXPIRED_ID_TOKEN,
|
||||||
|
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
|
||||||
|
OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
|
||||||
|
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
|
||||||
|
OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
|
UseRefreshToken: true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedToken: unexpiredTokenWithIDToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip token refresh when there is no refresh token",
|
||||||
|
identity: userIdentity,
|
||||||
|
setup: func(env *environment) {
|
||||||
|
env.authInfoService.ExpectedUserAuth = &login.UserAuth{
|
||||||
|
AuthModule: login.GenericOAuthModule,
|
||||||
|
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
|
||||||
OAuthRefreshToken: "",
|
OAuthRefreshToken: "",
|
||||||
|
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
|
||||||
}
|
}
|
||||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
UseRefreshToken: true,
|
UseRefreshToken: true,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
expectedToken: &oauth2.Token{
|
||||||
|
AccessToken: unexpiredTokenWithIDToken.AccessToken,
|
||||||
|
RefreshToken: "",
|
||||||
|
Expiry: unexpiredTokenWithIDToken.Expiry,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should do token refresh when the token is expired",
|
desc: "should do token refresh when the token is expired",
|
||||||
|
identity: userIdentity,
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "testrefresh",
|
|
||||||
Expiry: time.Now().Add(-time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
|
|
||||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
UseRefreshToken: true,
|
UseRefreshToken: true,
|
||||||
}
|
}
|
||||||
@ -297,24 +321,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
|
|||||||
AuthModule: login.GenericOAuthModule,
|
AuthModule: login.GenericOAuthModule,
|
||||||
AuthId: "subject",
|
AuthId: "subject",
|
||||||
UserId: 1,
|
UserId: 1,
|
||||||
OAuthAccessToken: token.AccessToken,
|
OAuthAccessToken: expiredToken.AccessToken,
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
OAuthRefreshToken: expiredToken.RefreshToken,
|
||||||
OAuthExpiry: token.Expiry,
|
OAuthExpiry: expiredToken.Expiry,
|
||||||
OAuthTokenType: token.TokenType,
|
OAuthTokenType: expiredToken.TokenType,
|
||||||
|
OAuthIdToken: EXPIRED_ID_TOKEN,
|
||||||
}
|
}
|
||||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
|
||||||
},
|
},
|
||||||
|
expectedToken: unexpiredTokenWithIDToken,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should refresh token when the id token is expired",
|
desc: "should refresh token when the id token is expired",
|
||||||
|
identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule},
|
||||||
setup: func(env *environment) {
|
setup: func(env *environment) {
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "testaccess",
|
|
||||||
RefreshToken: "testrefresh",
|
|
||||||
Expiry: time.Now().Add(time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
|
|
||||||
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
|
||||||
UseRefreshToken: true,
|
UseRefreshToken: true,
|
||||||
}
|
}
|
||||||
@ -322,19 +342,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
|
|||||||
AuthModule: login.GenericOAuthModule,
|
AuthModule: login.GenericOAuthModule,
|
||||||
AuthId: "subject",
|
AuthId: "subject",
|
||||||
UserId: 1,
|
UserId: 1,
|
||||||
OAuthAccessToken: token.AccessToken,
|
OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken,
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
|
||||||
OAuthExpiry: token.Expiry,
|
OAuthExpiry: unexpiredTokenWithIDToken.Expiry,
|
||||||
OAuthTokenType: token.TokenType,
|
OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
|
||||||
OAuthIdToken: EXPIRED_JWT,
|
OAuthIdToken: EXPIRED_ID_TOKEN,
|
||||||
}
|
}
|
||||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
|
||||||
},
|
},
|
||||||
|
expectedToken: unexpiredTokenWithIDToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.desc, func(t *testing.T) {
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
socialConnector := &socialtest.MockSocialConnector{}
|
socialConnector := socialtest.NewMockSocialConnector(t)
|
||||||
|
|
||||||
store := db.InitTestDB(t)
|
store := db.InitTestDB(t)
|
||||||
|
|
||||||
@ -361,11 +382,27 @@ func TestService_TryTokenRefresh(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// token refresh
|
// token refresh
|
||||||
_, err := env.service.TryTokenRefresh(context.Background(), env.identity)
|
actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity)
|
||||||
|
|
||||||
// test and validations
|
if tt.expectedErr != nil {
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
socialConnector.AssertExpectations(t)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedToken == nil {
|
||||||
|
assert.Nil(t, actualToken)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken)
|
||||||
|
assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken)
|
||||||
|
assert.Equal(t, tt.expectedToken.Expiry, actualToken.Expiry)
|
||||||
|
assert.Equal(t, tt.expectedToken.TokenType, actualToken.TokenType)
|
||||||
|
if tt.expectedToken.Extra("id_token") != nil {
|
||||||
|
assert.Equal(t, tt.expectedToken.Extra("id_token").(string), actualToken.Extra("id_token").(string))
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, actualToken.Extra("id_token"))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -392,7 +429,7 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "should flag token refresh with id token is expired",
|
name: "should flag token refresh with id token is expired",
|
||||||
usr: &login.UserAuth{
|
usr: &login.UserAuth{
|
||||||
OAuthIdToken: EXPIRED_JWT,
|
OAuthIdToken: EXPIRED_ID_TOKEN,
|
||||||
},
|
},
|
||||||
expectedTokenRefreshFlag: true,
|
expectedTokenRefreshFlag: true,
|
||||||
expectedTokenDuration: time.Second,
|
expectedTokenDuration: time.Second,
|
||||||
@ -408,125 +445,10 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
token, needsTokenRefresh := needTokenRefresh(tt.usr)
|
token := buildOAuthTokenFromAuthInfo(tt.usr)
|
||||||
|
needsTokenRefresh := needTokenRefresh(context.Background(), token)
|
||||||
|
|
||||||
assert.NotNil(t, token)
|
|
||||||
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
|
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
|
|
||||||
timeNow := time.Now()
|
|
||||||
token := &oauth2.Token{
|
|
||||||
AccessToken: "oauth_access_token",
|
|
||||||
RefreshToken: "refresh_token_found",
|
|
||||||
Expiry: timeNow,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
type environment struct {
|
|
||||||
authInfoService *authinfotest.FakeService
|
|
||||||
serverLock *serverlock.ServerLockService
|
|
||||||
socialConnector *socialtest.MockSocialConnector
|
|
||||||
socialService *socialtest.FakeSocialService
|
|
||||||
|
|
||||||
service *Service
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
desc string
|
|
||||||
expectedErr error
|
|
||||||
expectedToken *oauth2.Token
|
|
||||||
usr *login.UserAuth
|
|
||||||
setup func(env *environment)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
|
|
||||||
usr: &login.UserAuth{
|
|
||||||
UserId: int64(1234),
|
|
||||||
AuthModule: login.SAMLAuthModule,
|
|
||||||
},
|
|
||||||
expectedErr: ErrNotAnOAuthProvider,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
|
|
||||||
usr: &login.UserAuth{
|
|
||||||
UserId: int64(1234),
|
|
||||||
AuthModule: login.GenericOAuthModule,
|
|
||||||
},
|
|
||||||
expectedErr: ErrNoRefreshTokenFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should not refresh token if the token is not expired",
|
|
||||||
usr: &login.UserAuth{
|
|
||||||
UserId: int64(1234),
|
|
||||||
AuthModule: login.GenericOAuthModule,
|
|
||||||
OAuthAccessToken: token.AccessToken,
|
|
||||||
OAuthRefreshToken: token.RefreshToken,
|
|
||||||
OAuthExpiry: timeNow.Add(time.Hour),
|
|
||||||
OAuthTokenType: "Bearer",
|
|
||||||
},
|
|
||||||
expectedToken: &oauth2.Token{
|
|
||||||
AccessToken: token.AccessToken,
|
|
||||||
RefreshToken: token.RefreshToken,
|
|
||||||
Expiry: timeNow.Add(time.Hour),
|
|
||||||
TokenType: "Bearer",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "should update saved token if the user auth has new access/refresh tokens",
|
|
||||||
usr: &login.UserAuth{
|
|
||||||
UserId: int64(1234),
|
|
||||||
AuthModule: login.GenericOAuthModule,
|
|
||||||
OAuthAccessToken: "new_oauth_access_token",
|
|
||||||
OAuthRefreshToken: "new_refresh_token_found",
|
|
||||||
OAuthExpiry: timeNow,
|
|
||||||
},
|
|
||||||
expectedToken: &oauth2.Token{
|
|
||||||
AccessToken: "oauth_access_token",
|
|
||||||
RefreshToken: "refresh_token_found",
|
|
||||||
Expiry: timeNow,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
},
|
|
||||||
setup: func(env *environment) {
|
|
||||||
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.desc, func(t *testing.T) {
|
|
||||||
socialConnector := &socialtest.MockSocialConnector{}
|
|
||||||
|
|
||||||
store := db.InitTestDB(t)
|
|
||||||
|
|
||||||
env := environment{
|
|
||||||
authInfoService: &authinfotest.FakeService{},
|
|
||||||
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
|
|
||||||
socialConnector: socialConnector,
|
|
||||||
socialService: &socialtest.FakeSocialService{
|
|
||||||
ExpectedConnector: socialConnector,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.setup != nil {
|
|
||||||
tt.setup(&env)
|
|
||||||
}
|
|
||||||
|
|
||||||
env.service = ProvideService(
|
|
||||||
env.socialService,
|
|
||||||
env.authInfoService,
|
|
||||||
setting.NewCfg(),
|
|
||||||
prometheus.NewRegistry(),
|
|
||||||
env.serverLock,
|
|
||||||
tracing.InitializeTracerForTest(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)
|
|
||||||
|
|
||||||
if tt.expectedToken != nil {
|
|
||||||
assert.Equal(t, tt.expectedToken, token)
|
|
||||||
}
|
|
||||||
assert.ErrorIs(t, tt.expectedErr, err)
|
|
||||||
socialConnector.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -14,7 +14,7 @@ type MockOauthTokenService struct {
|
|||||||
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
|
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
|
||||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||||
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
||||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
|
InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester) error
|
||||||
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error)
|
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity.
|
|||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth) error {
|
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error {
|
||||||
if m.InvalidateOAuthTokensFunc != nil {
|
if m.InvalidateOAuthTokensFunc != nil {
|
||||||
return m.InvalidateOAuthTokensFunc(ctx, usr)
|
return m.InvalidateOAuthTokensFunc(ctx, usr)
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,6 @@ func (s *Service) TryTokenRefresh(context.Context, identity.Requester) (*oauth2.
|
|||||||
return s.Token, nil
|
return s.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) InvalidateOAuthTokens(context.Context, *login.UserAuth) error {
|
func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user