OAuth: Support PKCE (#39948)

This commit is contained in:
Emil Tullstedt 2021-10-13 16:45:15 +02:00 committed by GitHub
parent 09587240ec
commit e73cd2fdeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 232 additions and 8 deletions

View File

@ -522,6 +522,7 @@ tls_skip_verify_insecure = false
tls_client_cert =
tls_client_key =
tls_client_ca =
use_pkce = false
#################################### Basic Auth ##########################
[auth.basic]
@ -735,7 +736,7 @@ global_alert_rule = -1
enabled = false
# Comma-separated list of organization IDs for which to disable unified alerting. Only supported if unified alerting is enabled.
disabled_orgs =
disabled_orgs =
# Specify the frequency of polling for admin config changes.
# The interval string is a possibly signed sequence of decimal numbers, followed by a unit suffix (ms, s, m, h, d), e.g. 30s or 1m.

View File

@ -501,6 +501,7 @@
;tls_client_cert =
;tls_client_key =
;tls_client_ca =
;use_pkce = false
#################################### Basic Auth ##########################
[auth.basic]
@ -712,7 +713,7 @@
;enabled = false
# Comma-separated list of organization IDs for which to disable unified alerting. Only supported if unified alerting is enabled.
;disabled_orgs =
;disabled_orgs =
# Specify the frequency of polling for admin config changes.
# The interval string is a possibly signed sequence of decimal numbers, followed by a unit suffix (ms, s, m, h, d), e.g. 30s or 1m.

View File

@ -24,8 +24,12 @@ import (
)
var (
oauthLogger = log.New("oauth")
oauthLogger = log.New("oauth")
)
const (
OauthStateCookieName = "oauth_state"
OauthPKCECookieName = "oauth_code_verifier"
)
func GenStateString() (string, error) {
@ -37,6 +41,32 @@ func GenStateString() (string, error) {
return base64.URLEncoding.EncodeToString(rnd), nil
}
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
func genPKCECode() (string, string, error) {
// IETF RFC 7636 specifies that the code verifier should be 43-128
// characters from a set of unreserved URI characters which is
// almost the same as the set of characters in base64url.
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
//
// It doesn't hurt to generate a few more bytes here, we generate
// 96 bytes which we then encode using base64url to make sure
// they're within the set of unreserved characters.
//
// 96 is chosen because 96*8/6 = 128, which means that we'll have
// 128 characters after it has been base64 encoded.
raw := make([]byte, 96)
_, err := rand.Read(raw)
if err != nil {
return "", "", err
}
ascii := make([]byte, 128)
base64.RawURLEncoding.Encode(ascii, raw)
shasum := sha256.Sum256(ascii)
pkce := base64.RawURLEncoding.EncodeToString(shasum[:])
return string(ascii), pkce, nil
}
func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
loginInfo := models.LoginInfo{
AuthModule: "oauth",
@ -71,6 +101,26 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
code := ctx.Query("code")
if code == "" {
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}
if provider.UsePKCE {
ascii, pkce, err := genPKCECode()
if err != nil {
ctx.Logger.Error("Generating PKCE failed", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred",
})
}
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
opts = append(opts,
oauth2.SetAuthURLParam("code_challenge", pkce),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
}
state, err := GenStateString()
if err != nil {
ctx.Logger.Error("Generating state string failed", "err", err)
@ -83,11 +133,11 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
hashedState := hashStatecode(state, provider.ClientSecret)
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if provider.HostedDomain == "" {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
} else {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", provider.HostedDomain), oauth2.AccessTypeOnline))
if provider.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
}
ctx.Redirect(connect.AuthCodeURL(state, opts...))
return
}
@ -125,9 +175,18 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
}
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
opts := []oauth2.AuthCodeOption{}
codeVerifier := ctx.GetCookie(OauthPKCECookieName)
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
if codeVerifier != "" {
opts = append(opts,
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
)
}
// get token from provider
token, err := connect.Exchange(oauthCtx, code)
token, err := connect.Exchange(oauthCtx, code, opts...)
if err != nil {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,

161
pkg/api/login_oauth_test.go Normal file
View File

@ -0,0 +1,161 @@
package api
import (
"crypto/sha256"
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
)
func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *macaron.Macaron {
t.Helper()
if cfg == nil {
cfg = setting.NewCfg()
}
cfg.ErrTemplateName = "error-template"
sqlStore := sqlstore.InitTestDB(t)
hs := &HTTPServer{
Cfg: cfg,
Bus: bus.GetBus(),
License: &licensing.OSSLicensingService{Cfg: cfg},
SQLStore: sqlStore,
SocialService: social.ProvideService(cfg),
HooksService: hooks.ProvideService(),
}
m := macaron.New()
m.Use(getContextHandler(t, cfg).Middleware)
viewPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
m.UseMiddleware(macaron.Renderer(viewPath, "[[", "]]"))
m.Get("/login/:name", routing.Wrap(hs.OAuthLogin))
return m
}
func TestOAuthLogin_UnknownProvider(t *testing.T) {
m := setupOAuthTest(t, nil)
req := httptest.NewRequest(http.MethodGet, "/login/notaprovider", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusNotFound, recorder.Code)
assert.Contains(t, recorder.Body.String(), "OAuth not enabled")
}
func TestOAuthLogin_Base(t *testing.T) {
cfg := setting.NewCfg()
sec := cfg.Raw.Section("auth.generic_oauth")
_, err := sec.NewKey("enabled", "true")
require.NoError(t, err)
m := setupOAuthTest(t, cfg)
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.NotEmpty(t, location)
u, err := url.Parse(location)
require.NoError(t, err)
assert.False(t, u.Query().Has("code_challenge"))
assert.False(t, u.Query().Has("code_challenge_method"))
resp := recorder.Result()
require.NoError(t, resp.Body.Close())
cookies := resp.Cookies()
var stateCookie *http.Cookie
for _, c := range cookies {
if c.Name == OauthStateCookieName {
stateCookie = c
}
}
require.NotNil(t, stateCookie)
req = httptest.NewRequest(
http.MethodGet,
(&url.URL{
Path: "/login/generic_oauth",
RawQuery: url.Values{
"code": []string{"helloworld"},
"state": []string{u.Query().Get("state")},
}.Encode(),
}).String(),
nil,
)
req.AddCookie(stateCookie)
recorder = httptest.NewRecorder()
m.ServeHTTP(recorder, req)
// TODO: validate that 'creating a token works'
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "login.OAuthLogin(NewTransportWithCode)")
}
func TestOAuthLogin_UsePKCE(t *testing.T) {
cfg := setting.NewCfg()
sec := cfg.Raw.Section("auth.generic_oauth")
_, err := sec.NewKey("enabled", "true")
require.NoError(t, err)
_, err = sec.NewKey("use_pkce", "true")
require.NoError(t, err)
m := setupOAuthTest(t, cfg)
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.NotEmpty(t, location)
u, err := url.Parse(location)
require.NoError(t, err)
assert.True(t, u.Query().Has("code_challenge"))
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
resp := recorder.Result()
require.NoError(t, resp.Body.Close())
var oauthCookie *http.Cookie
for _, cookie := range resp.Cookies() {
if cookie.Name == OauthPKCECookieName {
oauthCookie = cookie
}
}
require.NotNil(t, oauthCookie)
shasum := sha256.Sum256([]byte(oauthCookie.Value))
assert.Equal(
t,
u.Query().Get("code_challenge"),
base64.RawURLEncoding.EncodeToString(shasum[:]),
)
}

View File

@ -49,6 +49,7 @@ type OAuthInfo struct {
TlsClientKey string
TlsClientCa string
TlsSkipVerify bool
UsePKCE bool
}
func ProvideService(cfg *setting.Cfg) *SocialService {
@ -84,6 +85,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
TlsClientKey: sec.Key("tls_client_key").String(),
TlsClientCa: sec.Key("tls_client_ca").String(),
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
UsePKCE: sec.Key("use_pkce").MustBool(),
}
// when empty_scopes parameter exists and is true, overwrite scope with empty value