mirror of
https://github.com/grafana/grafana.git
synced 2025-02-14 01:23:32 -06:00
Auth forwarding: Pass tokens without refresh (#61634)
* return only tokens from oauth * feedback
This commit is contained in:
parent
f25d5199c5
commit
ecafb4dd15
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,4 +1,4 @@
|
||||
package usersync
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
@ -72,8 +72,13 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
|
||||
|
||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return buildOAuthTokenFromAuthInfo(authInfoQuery.Result)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
@ -111,23 +116,43 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) err
|
||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
|
||||
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) {
|
||||
logger.Debug("singleflight request for getting a new access token", "key", lockKey)
|
||||
authProvider := usr.AuthModule
|
||||
|
||||
if !strings.Contains(authProvider, "oauth") {
|
||||
logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||
return nil, ErrNotAnOAuthProvider
|
||||
}
|
||||
|
||||
if usr.OAuthRefreshToken == "" {
|
||||
logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||
return nil, ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
return o.tryGetOrRefreshAccessToken(ctx, usr)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func buildOAuthTokenFromAuthInfo(authInfo *models.UserAuth) *oauth2.Token {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: authInfo.OAuthAccessToken,
|
||||
Expiry: authInfo.OAuthExpiry,
|
||||
RefreshToken: authInfo.OAuthRefreshToken,
|
||||
TokenType: authInfo.OAuthTokenType,
|
||||
}
|
||||
|
||||
if authInfo.OAuthIdToken != "" {
|
||||
token = token.WithExtra(map[string]interface{}{"id_token": authInfo.OAuthIdToken})
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func checkOAuthRefreshToken(authInfo *models.UserAuth) error {
|
||||
if !strings.Contains(authInfo.AuthModule, "oauth") {
|
||||
logger.Warn("the specified user's auth provider is not oauth",
|
||||
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
|
||||
return ErrNotAnOAuthProvider
|
||||
}
|
||||
|
||||
if authInfo.OAuthRefreshToken == "" {
|
||||
logger.Debug("no refresh token available",
|
||||
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
|
||||
return ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
||||
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{
|
||||
@ -143,6 +168,10 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAut
|
||||
}
|
||||
|
||||
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) {
|
||||
if err := checkOAuthRefreshToken(usr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authProvider := usr.AuthModule
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
if err != nil {
|
||||
@ -157,21 +186,13 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.Us
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
|
||||
persistedToken := &oauth2.Token{
|
||||
AccessToken: usr.OAuthAccessToken,
|
||||
Expiry: usr.OAuthExpiry,
|
||||
RefreshToken: usr.OAuthRefreshToken,
|
||||
TokenType: usr.OAuthTokenType,
|
||||
}
|
||||
|
||||
if usr.OAuthIdToken != "" {
|
||||
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken})
|
||||
}
|
||||
persistedToken := buildOAuthTokenFromAuthInfo(usr)
|
||||
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
if err != nil {
|
||||
logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err)
|
||||
logger.Error("failed to retrieve oauth access token",
|
||||
"provider", usr.AuthModule, "userId", usr.UserId, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user