mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
OAuth: Support PKCE (#39948)
This commit is contained in:
parent
09587240ec
commit
e73cd2fdeb
@ -522,6 +522,7 @@ tls_skip_verify_insecure = false
|
||||
tls_client_cert =
|
||||
tls_client_key =
|
||||
tls_client_ca =
|
||||
use_pkce = false
|
||||
|
||||
#################################### Basic Auth ##########################
|
||||
[auth.basic]
|
||||
@ -735,7 +736,7 @@ global_alert_rule = -1
|
||||
enabled = false
|
||||
|
||||
# Comma-separated list of organization IDs for which to disable unified alerting. Only supported if unified alerting is enabled.
|
||||
disabled_orgs =
|
||||
disabled_orgs =
|
||||
|
||||
# Specify the frequency of polling for admin config changes.
|
||||
# The interval string is a possibly signed sequence of decimal numbers, followed by a unit suffix (ms, s, m, h, d), e.g. 30s or 1m.
|
||||
|
@ -501,6 +501,7 @@
|
||||
;tls_client_cert =
|
||||
;tls_client_key =
|
||||
;tls_client_ca =
|
||||
;use_pkce = false
|
||||
|
||||
#################################### Basic Auth ##########################
|
||||
[auth.basic]
|
||||
@ -712,7 +713,7 @@
|
||||
;enabled = false
|
||||
|
||||
# Comma-separated list of organization IDs for which to disable unified alerting. Only supported if unified alerting is enabled.
|
||||
;disabled_orgs =
|
||||
;disabled_orgs =
|
||||
|
||||
# Specify the frequency of polling for admin config changes.
|
||||
# The interval string is a possibly signed sequence of decimal numbers, followed by a unit suffix (ms, s, m, h, d), e.g. 30s or 1m.
|
||||
|
@ -24,8 +24,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
oauthLogger = log.New("oauth")
|
||||
oauthLogger = log.New("oauth")
|
||||
)
|
||||
|
||||
const (
|
||||
OauthStateCookieName = "oauth_state"
|
||||
OauthPKCECookieName = "oauth_code_verifier"
|
||||
)
|
||||
|
||||
func GenStateString() (string, error) {
|
||||
@ -37,6 +41,32 @@ func GenStateString() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(rnd), nil
|
||||
}
|
||||
|
||||
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
|
||||
func genPKCECode() (string, string, error) {
|
||||
// IETF RFC 7636 specifies that the code verifier should be 43-128
|
||||
// characters from a set of unreserved URI characters which is
|
||||
// almost the same as the set of characters in base64url.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
|
||||
//
|
||||
// It doesn't hurt to generate a few more bytes here, we generate
|
||||
// 96 bytes which we then encode using base64url to make sure
|
||||
// they're within the set of unreserved characters.
|
||||
//
|
||||
// 96 is chosen because 96*8/6 = 128, which means that we'll have
|
||||
// 128 characters after it has been base64 encoded.
|
||||
raw := make([]byte, 96)
|
||||
_, err := rand.Read(raw)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
ascii := make([]byte, 128)
|
||||
base64.RawURLEncoding.Encode(ascii, raw)
|
||||
|
||||
shasum := sha256.Sum256(ascii)
|
||||
pkce := base64.RawURLEncoding.EncodeToString(shasum[:])
|
||||
return string(ascii), pkce, nil
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
loginInfo := models.LoginInfo{
|
||||
AuthModule: "oauth",
|
||||
@ -71,6 +101,26 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
|
||||
code := ctx.Query("code")
|
||||
if code == "" {
|
||||
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}
|
||||
|
||||
if provider.UsePKCE {
|
||||
ascii, pkce, err := genPKCECode()
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Generating PKCE failed", "error", err)
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "An internal error occurred",
|
||||
})
|
||||
}
|
||||
|
||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam("code_challenge", pkce),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
}
|
||||
|
||||
state, err := GenStateString()
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Generating state string failed", "err", err)
|
||||
@ -83,11 +133,11 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
|
||||
hashedState := hashStatecode(state, provider.ClientSecret)
|
||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
if provider.HostedDomain == "" {
|
||||
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
|
||||
} else {
|
||||
ctx.Redirect(connect.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", provider.HostedDomain), oauth2.AccessTypeOnline))
|
||||
if provider.HostedDomain != "" {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
|
||||
}
|
||||
|
||||
ctx.Redirect(connect.AuthCodeURL(state, opts...))
|
||||
return
|
||||
}
|
||||
|
||||
@ -125,9 +175,18 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
}
|
||||
|
||||
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
|
||||
opts := []oauth2.AuthCodeOption{}
|
||||
|
||||
codeVerifier := ctx.GetCookie(OauthPKCECookieName)
|
||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
if codeVerifier != "" {
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
// get token from provider
|
||||
token, err := connect.Exchange(oauthCtx, code)
|
||||
token, err := connect.Exchange(oauthCtx, code, opts...)
|
||||
if err != nil {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
|
161
pkg/api/login_oauth_test.go
Normal file
161
pkg/api/login_oauth_test.go
Normal file
@ -0,0 +1,161 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/macaron.v1"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/hooks"
|
||||
"github.com/grafana/grafana/pkg/services/licensing"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *macaron.Macaron {
|
||||
t.Helper()
|
||||
|
||||
if cfg == nil {
|
||||
cfg = setting.NewCfg()
|
||||
}
|
||||
cfg.ErrTemplateName = "error-template"
|
||||
|
||||
sqlStore := sqlstore.InitTestDB(t)
|
||||
|
||||
hs := &HTTPServer{
|
||||
Cfg: cfg,
|
||||
Bus: bus.GetBus(),
|
||||
License: &licensing.OSSLicensingService{Cfg: cfg},
|
||||
SQLStore: sqlStore,
|
||||
SocialService: social.ProvideService(cfg),
|
||||
HooksService: hooks.ProvideService(),
|
||||
}
|
||||
|
||||
m := macaron.New()
|
||||
m.Use(getContextHandler(t, cfg).Middleware)
|
||||
viewPath, err := filepath.Abs("../../public/views")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.UseMiddleware(macaron.Renderer(viewPath, "[[", "]]"))
|
||||
|
||||
m.Get("/login/:name", routing.Wrap(hs.OAuthLogin))
|
||||
return m
|
||||
}
|
||||
|
||||
func TestOAuthLogin_UnknownProvider(t *testing.T) {
|
||||
m := setupOAuthTest(t, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/notaprovider", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "OAuth not enabled")
|
||||
}
|
||||
|
||||
func TestOAuthLogin_Base(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
||||
_, err := sec.NewKey("enabled", "true")
|
||||
require.NoError(t, err)
|
||||
|
||||
m := setupOAuthTest(t, cfg)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.NotEmpty(t, location)
|
||||
|
||||
u, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, u.Query().Has("code_challenge"))
|
||||
assert.False(t, u.Query().Has("code_challenge_method"))
|
||||
|
||||
resp := recorder.Result()
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
cookies := resp.Cookies()
|
||||
var stateCookie *http.Cookie
|
||||
for _, c := range cookies {
|
||||
if c.Name == OauthStateCookieName {
|
||||
stateCookie = c
|
||||
}
|
||||
}
|
||||
require.NotNil(t, stateCookie)
|
||||
|
||||
req = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
(&url.URL{
|
||||
Path: "/login/generic_oauth",
|
||||
RawQuery: url.Values{
|
||||
"code": []string{"helloworld"},
|
||||
"state": []string{u.Query().Get("state")},
|
||||
}.Encode(),
|
||||
}).String(),
|
||||
nil,
|
||||
)
|
||||
req.AddCookie(stateCookie)
|
||||
recorder = httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
// TODO: validate that 'creating a token works'
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "login.OAuthLogin(NewTransportWithCode)")
|
||||
}
|
||||
|
||||
func TestOAuthLogin_UsePKCE(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
||||
_, err := sec.NewKey("enabled", "true")
|
||||
require.NoError(t, err)
|
||||
_, err = sec.NewKey("use_pkce", "true")
|
||||
require.NoError(t, err)
|
||||
|
||||
m := setupOAuthTest(t, cfg)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.NotEmpty(t, location)
|
||||
|
||||
u, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, u.Query().Has("code_challenge"))
|
||||
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
|
||||
|
||||
resp := recorder.Result()
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
var oauthCookie *http.Cookie
|
||||
for _, cookie := range resp.Cookies() {
|
||||
if cookie.Name == OauthPKCECookieName {
|
||||
oauthCookie = cookie
|
||||
}
|
||||
}
|
||||
require.NotNil(t, oauthCookie)
|
||||
|
||||
shasum := sha256.Sum256([]byte(oauthCookie.Value))
|
||||
assert.Equal(
|
||||
t,
|
||||
u.Query().Get("code_challenge"),
|
||||
base64.RawURLEncoding.EncodeToString(shasum[:]),
|
||||
)
|
||||
}
|
@ -49,6 +49,7 @@ type OAuthInfo struct {
|
||||
TlsClientKey string
|
||||
TlsClientCa string
|
||||
TlsSkipVerify bool
|
||||
UsePKCE bool
|
||||
}
|
||||
|
||||
func ProvideService(cfg *setting.Cfg) *SocialService {
|
||||
@ -84,6 +85,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
|
||||
TlsClientKey: sec.Key("tls_client_key").String(),
|
||||
TlsClientCa: sec.Key("tls_client_ca").String(),
|
||||
TlsSkipVerify: sec.Key("tls_skip_verify_insecure").MustBool(),
|
||||
UsePKCE: sec.Key("use_pkce").MustBool(),
|
||||
}
|
||||
|
||||
// when empty_scopes parameter exists and is true, overwrite scope with empty value
|
||||
|
Loading…
Reference in New Issue
Block a user