refactor: use golang.org/x/oauth2 pkce option (#80511)

Signed-off-by: junya koyama <arukiidou@yahoo.co.jp>
This commit is contained in:
arukiidou 2024-01-16 00:24:02 +09:00 committed by GitHub
parent 1cf53a34d1
commit bffb28c177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 14 deletions

View File

@ -112,7 +112,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
if err != nil { if err != nil {
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err) return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
} }
opts = append(opts, oauth2.SetAuthURLParam(codeVerifierParamName, pkceCookie.Value)) opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
} }
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
@ -184,16 +184,13 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
var plainPKCE string var plainPKCE string
if c.oauthCfg.UsePKCE { if c.oauthCfg.UsePKCE {
pkce, hashedPKCE, err := genPKCECode() verifier, err := genPKCECodeVerifier()
if err != nil { if err != nil {
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err) return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
} }
plainPKCE = pkce plainPKCE = verifier
opts = append(opts, opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
oauth2.SetAuthURLParam(codeChallengeParamName, hashedPKCE),
oauth2.SetAuthURLParam(codeChallengeMethodParamName, codeChallengeMethod),
)
} }
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret) state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
@ -233,8 +230,8 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login
return &authn.Redirect{URL: redirctURL}, true return &authn.Redirect{URL: redirctURL}, true
} }
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest. // genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.
func genPKCECode() (string, string, error) { func genPKCECodeVerifier() (string, error) {
// IETF RFC 7636 specifies that the code verifier should be 43-128 // IETF RFC 7636 specifies that the code verifier should be 43-128
// characters from a set of unreserved URI characters which is // characters from a set of unreserved URI characters which is
// almost the same as the set of characters in base64url. // almost the same as the set of characters in base64url.
@ -249,14 +246,12 @@ func genPKCECode() (string, string, error) {
raw := make([]byte, 96) raw := make([]byte, 96)
_, err := rand.Read(raw) _, err := rand.Read(raw)
if err != nil { if err != nil {
return "", "", err return "", err
} }
ascii := make([]byte, 128) ascii := make([]byte, 128)
base64.RawURLEncoding.Encode(ascii, raw) base64.RawURLEncoding.Encode(ascii, raw)
shasum := sha256.Sum256(ascii) return string(ascii), nil
pkce := base64.RawURLEncoding.EncodeToString(shasum[:])
return string(ascii), pkce, nil
} }
func genOAuthState(secret, seed string) (string, string, error) { func genOAuthState(secret, seed string) (string, string, error) {

View File

@ -268,7 +268,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
{ {
desc: "should generate redirect url with pkce if configured", desc: "should generate redirect url with pkce if configured",
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true},
numCallOptions: 2, numCallOptions: 1,
authCodeUrlCalled: true, authCodeUrlCalled: true,
}, },
} }
@ -404,6 +404,12 @@ func TestOAuth_Logout(t *testing.T) {
} }
} }
func TestGenPKCECodeVerifier(t *testing.T) {
verifier, err := genPKCECodeVerifier()
assert.NoError(t, err)
assert.Len(t, verifier, 128)
}
type mockConnector struct { type mockConnector struct {
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
social.SocialConnector social.SocialConnector