mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
OAuth: persisting the id token (#42938)
* OAuth: persisting the id token * OAuth: verifies that the idtoken gets persistet in the database
This commit is contained in:
parent
4a3961400a
commit
5d18834deb
@ -19,6 +19,7 @@ type UserAuth struct {
|
||||
Created time.Time
|
||||
OAuthAccessToken string
|
||||
OAuthRefreshToken string
|
||||
OAuthIdToken string
|
||||
OAuthTokenType string
|
||||
OAuthExpiry time.Time
|
||||
}
|
||||
|
@ -71,9 +71,14 @@ func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthI
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretIdToken, err := s.decodeAndDecrypt(userAuth.OAuthIdToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.OAuthAccessToken = secretAccessToken
|
||||
userAuth.OAuthRefreshToken = secretRefreshToken
|
||||
userAuth.OAuthTokenType = secretTokenType
|
||||
userAuth.OAuthIdToken = secretIdToken
|
||||
|
||||
query.Result = userAuth
|
||||
return nil
|
||||
@ -101,9 +106,18 @@ func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInf
|
||||
return err
|
||||
}
|
||||
|
||||
var secretIdToken string
|
||||
if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
|
||||
secretIdToken, err = s.encryptAndEncode(idToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
authUser.OAuthAccessToken = secretAccessToken
|
||||
authUser.OAuthRefreshToken = secretRefreshToken
|
||||
authUser.OAuthTokenType = secretTokenType
|
||||
authUser.OAuthIdToken = secretIdToken
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
@ -135,9 +149,18 @@ func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateA
|
||||
return err
|
||||
}
|
||||
|
||||
var secretIdToken string
|
||||
if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
|
||||
secretIdToken, err = s.encryptAndEncode(idToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
authUser.OAuthAccessToken = secretAccessToken
|
||||
authUser.OAuthRefreshToken = secretRefreshToken
|
||||
authUser.OAuthTokenType = secretTokenType
|
||||
authUser.OAuthIdToken = secretIdToken
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
|
@ -133,6 +133,8 @@ func TestUserAuth(t *testing.T) {
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
idToken := "testidtoken"
|
||||
token = token.WithExtra(map[string]interface{}{"id_token": idToken})
|
||||
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
@ -161,9 +163,10 @@ func TestUserAuth(t *testing.T) {
|
||||
err = srv.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
|
||||
require.Equal(t, getAuthQuery.Result.OAuthRefreshToken, token.RefreshToken)
|
||||
require.Equal(t, getAuthQuery.Result.OAuthTokenType, token.TokenType)
|
||||
require.Equal(t, token.AccessToken, getAuthQuery.Result.OAuthAccessToken)
|
||||
require.Equal(t, token.RefreshToken, getAuthQuery.Result.OAuthRefreshToken)
|
||||
require.Equal(t, token.TokenType, getAuthQuery.Result.OAuthTokenType)
|
||||
require.Equal(t, idToken, getAuthQuery.Result.OAuthIdToken)
|
||||
})
|
||||
|
||||
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
||||
|
@ -68,6 +68,11 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
|
||||
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
|
||||
TokenType: authInfoQuery.Result.OAuthTokenType,
|
||||
}
|
||||
|
||||
if authInfoQuery.Result.OAuthIdToken != "" {
|
||||
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken})
|
||||
}
|
||||
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
if err != nil {
|
||||
|
@ -42,4 +42,8 @@ func addUserAuthMigrations(mg *Migrator) {
|
||||
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{
|
||||
Cols: []string{"user_id"},
|
||||
}))
|
||||
|
||||
mg.AddMigration("Add OAuth ID token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
|
||||
Name: "o_auth_id_token", Type: DB_Text, Nullable: true,
|
||||
}))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user