Auth forwarding: Pass tokens without refresh (#61634)

* return only tokens from oauth

* feedback
This commit is contained in:
Jo 2023-01-18 10:50:35 +00:00 committed by GitHub
parent f25d5199c5
commit ecafb4dd15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 24 deletions

View File

@ -72,8 +72,13 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
if err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfoQuery.Result)
}
return nil
}
return token
}
@ -111,23 +116,43 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) err
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) {
logger.Debug("singleflight request for getting a new access token", "key", lockKey)
authProvider := usr.AuthModule
if !strings.Contains(authProvider, "oauth") {
logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId)
return nil, ErrNotAnOAuthProvider
}
if usr.OAuthRefreshToken == "" {
logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId)
return nil, ErrNoRefreshTokenFound
}
return o.tryGetOrRefreshAccessToken(ctx, usr)
})
return err
}
func buildOAuthTokenFromAuthInfo(authInfo *models.UserAuth) *oauth2.Token {
token := &oauth2.Token{
AccessToken: authInfo.OAuthAccessToken,
Expiry: authInfo.OAuthExpiry,
RefreshToken: authInfo.OAuthRefreshToken,
TokenType: authInfo.OAuthTokenType,
}
if authInfo.OAuthIdToken != "" {
token = token.WithExtra(map[string]interface{}{"id_token": authInfo.OAuthIdToken})
}
return token
}
func checkOAuthRefreshToken(authInfo *models.UserAuth) error {
if !strings.Contains(authInfo.AuthModule, "oauth") {
logger.Warn("the specified user's auth provider is not oauth",
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
return ErrNotAnOAuthProvider
}
if authInfo.OAuthRefreshToken == "" {
logger.Debug("no refresh token available",
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
return ErrNoRefreshTokenFound
}
return nil
}
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{
@ -143,6 +168,10 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAut
}
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) {
if err := checkOAuthRefreshToken(usr); err != nil {
return nil, err
}
authProvider := usr.AuthModule
connect, err := o.SocialService.GetConnector(authProvider)
if err != nil {
@ -157,21 +186,13 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.Us
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
persistedToken := &oauth2.Token{
AccessToken: usr.OAuthAccessToken,
Expiry: usr.OAuthExpiry,
RefreshToken: usr.OAuthRefreshToken,
TokenType: usr.OAuthTokenType,
}
if usr.OAuthIdToken != "" {
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken})
}
persistedToken := buildOAuthTokenFromAuthInfo(usr)
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
if err != nil {
logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err)
logger.Error("failed to retrieve oauth access token",
"provider", usr.AuthModule, "userId", usr.UserId, "error", err)
return nil, err
}