OAuth: persisting the id token (#42938)

* OAuth: persisting the id token

* OAuth: verifies that the idtoken gets persistet in the database
This commit is contained in:
Leonard Gram 2021-12-14 15:22:10 +01:00 committed by GitHub
parent 4a3961400a
commit 5d18834deb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 3 deletions

View File

@ -19,6 +19,7 @@ type UserAuth struct {
Created time.Time
OAuthAccessToken string
OAuthRefreshToken string
OAuthIdToken string
OAuthTokenType string
OAuthExpiry time.Time
}

View File

@ -71,9 +71,14 @@ func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthI
if err != nil {
return err
}
secretIdToken, err := s.decodeAndDecrypt(userAuth.OAuthIdToken)
if err != nil {
return err
}
userAuth.OAuthAccessToken = secretAccessToken
userAuth.OAuthRefreshToken = secretRefreshToken
userAuth.OAuthTokenType = secretTokenType
userAuth.OAuthIdToken = secretIdToken
query.Result = userAuth
return nil
@ -101,9 +106,18 @@ func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInf
return err
}
var secretIdToken string
if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
secretIdToken, err = s.encryptAndEncode(idToken)
if err != nil {
return err
}
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthIdToken = secretIdToken
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
@ -135,9 +149,18 @@ func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateA
return err
}
var secretIdToken string
if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
secretIdToken, err = s.encryptAndEncode(idToken)
if err != nil {
return err
}
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthIdToken = secretIdToken
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}

View File

@ -133,6 +133,8 @@ func TestUserAuth(t *testing.T) {
Expiry: time.Now(),
TokenType: "Bearer",
}
idToken := "testidtoken"
token = token.WithExtra(map[string]interface{}{"id_token": idToken})
// Find a user to set tokens on
login := "loginuser0"
@ -161,9 +163,10 @@ func TestUserAuth(t *testing.T) {
err = srv.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken)
require.Equal(t, getAuthQuery.Result.OAuthRefreshToken, token.RefreshToken)
require.Equal(t, getAuthQuery.Result.OAuthTokenType, token.TokenType)
require.Equal(t, token.AccessToken, getAuthQuery.Result.OAuthAccessToken)
require.Equal(t, token.RefreshToken, getAuthQuery.Result.OAuthRefreshToken)
require.Equal(t, token.TokenType, getAuthQuery.Result.OAuthTokenType)
require.Equal(t, idToken, getAuthQuery.Result.OAuthIdToken)
})
t.Run("Always return the most recently used auth_module", func(t *testing.T) {

View File

@ -68,6 +68,11 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
TokenType: authInfoQuery.Result.OAuthTokenType,
}
if authInfoQuery.Result.OAuthIdToken != "" {
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken})
}
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
if err != nil {

View File

@ -42,4 +42,8 @@ func addUserAuthMigrations(mg *Migrator) {
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{
Cols: []string{"user_id"},
}))
mg.AddMigration("Add OAuth ID token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_id_token", Type: DB_Text, Nullable: true,
}))
}