AuthN: Support reloading SSO config after the sso settings have changed (#80734)

* Add AuthNSvc reload handling

* Working, need to add test

* Remove commented out code

* Add Reload implementation to connectors

* Align and add tests, refactor

* Add more tests, linting

* Add extra checks + tests to oauth client

* Clean up based on reviews

* Move config instantiation into newSocialBase

* Use specific error
This commit is contained in:
Misi
2024-01-22 14:54:48 +01:00
committed by GitHub
parent 1f4a520b9d
commit 20bb0a3ab1
31 changed files with 889 additions and 217 deletions

View File

@@ -140,18 +140,8 @@ func ProvideService(
}
for name := range socialService.GetOAuthProviders() {
oauthCfg := socialService.GetOAuthInfoProvider(name)
if oauthCfg != nil && oauthCfg.Enabled {
clientName := authn.ClientWithPrefix(name)
connector, errConnector := socialService.GetConnector(name)
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
if errConnector != nil || errHTTPClient != nil {
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
} else {
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
}
}
clientName := authn.ClientWithPrefix(name)
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthTokenService, socialService))
}
// FIXME (jguer): move to User package

View File

@@ -8,7 +8,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
@@ -40,6 +39,9 @@ const (
)
var (
errOAuthClientDisabled = errutil.BadRequest("auth.oauth.disabled", errutil.WithPublicMessage("OAuth client is disabled"))
errOAuthInternal = errutil.Internal("auth.oauth.internal", errutil.WithPublicMessage("An internal error occurred in the OAuth client"))
errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
@@ -62,24 +64,25 @@ var _ authn.LogoutClient = new(OAuth)
var _ authn.RedirectClient = new(OAuth)
func ProvideOAuth(
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
name string, cfg *setting.Cfg, oauthService oauthtoken.OAuthTokenService,
socialService social.Service,
) *OAuth {
providerName := strings.TrimPrefix(name, "auth.client.")
return &OAuth{
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
name, fmt.Sprintf("oauth_%s", providerName), providerName,
log.New(name), cfg, oauthService, socialService,
}
}
type OAuth struct {
name string
moduleName string
providerName string
log log.Logger
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector
httpClient *http.Client
oauthService oauthtoken.OAuthTokenService
oauthService oauthtoken.OAuthTokenService
socialService social.Service
}
func (c *OAuth) Name() string {
@@ -88,6 +91,12 @@ func (c *OAuth) Name() string {
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
// get hashed state stored in cookie
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
if err != nil {
@@ -99,7 +108,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
}
// get state returned by the idp and hash it
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, oauthCfg.ClientSecret)
// compare the state returned by idp against the one we stored in cookie
if stateQuery != stateCookie.Value {
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
@@ -107,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
var opts []oauth2.AuthCodeOption
// if pkce is enabled for client validate we have the cookie and set it as url param
if c.oauthCfg.UsePKCE {
if oauthCfg.UsePKCE {
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
if err != nil {
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
@@ -115,15 +124,21 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
connector, errConnector := c.socialService.GetConnector(c.providerName)
httpClient, errHTTPClient := c.socialService.GetOAuthHttpClient(c.providerName)
if errConnector != nil || errHTTPClient != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth client: %w", c.name, errors.Join(errConnector, errHTTPClient))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// exchange auth code to a valid token
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
token, err := connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
if err != nil {
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
}
token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
userInfo, err := connector.UserInfo(ctx, connector.Client(clientCtx, token), token)
if err != nil {
var sErr *connectors.SocialError
if errors.As(err, &sErr) {
@@ -136,7 +151,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
}
if !c.connector.IsEmailAllowed(userInfo.Email) {
if !connector.IsEmailAllowed(userInfo.Email) {
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
}
@@ -167,7 +182,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
SyncTeams: true,
FetchSyncedUser: true,
SyncPermissions: true,
AllowSignUp: c.connector.IsSignupAllowed(),
AllowSignUp: connector.IsSignupAllowed(),
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
SyncOrgRoles: len(orgRoles) > 0,
LookUpParams: lookupParams,
@@ -178,12 +193,17 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
var opts []oauth2.AuthCodeOption
if c.oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
if oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, oauthCfg.HostedDomain))
}
var plainPKCE string
if c.oauthCfg.UsePKCE {
if oauthCfg.UsePKCE {
verifier, err := genPKCECodeVerifier()
if err != nil {
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
@@ -193,13 +213,18 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
}
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, oauthCfg.ClientSecret)
if err != nil {
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
}
connector, err := c.socialService.GetConnector(c.providerName)
if err != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth connector: %w", c.name, err)
}
return &authn.Redirect{
URL: c.connector.AuthCodeURL(state, opts...),
URL: connector.AuthCodeURL(state, opts...),
Extra: map[string]string{
authn.KeyOAuthState: hashedSate,
authn.KeyOAuthPKCE: plainPKCE,
@@ -215,19 +240,25 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
}
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
if redirctURL == "" {
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
c.log.FromContext(ctx).Debug("OAuth client is disabled")
return nil, false
}
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
if redirectURL == "" {
c.log.FromContext(ctx).Debug("No signout redirect url configured")
return nil, false
}
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
if isOICDLogout(redirectURL) && token != nil && token.Valid() {
if idToken, ok := token.Extra("id_token").(string); ok {
redirctURL = withIDTokenHint(redirctURL, idToken)
redirectURL = withIDTokenHint(redirectURL, idToken)
}
}
return &authn.Redirect{URL: redirctURL}, true
return &authn.Redirect{URL: redirectURL}, true
}
// genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
@@ -46,17 +47,23 @@ func TestOAuth_Authenticate(t *testing.T) {
{
desc: "should return error when missing state cookie",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when state cookie is present but don't have a value",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
addStateCookie: true,
stateCookieValue: "",
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when the client is not enabled",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{Enabled: false},
expectedErr: errOAuthClientDisabled,
},
{
desc: "should return error when state from ipd does not match stored state",
req: &authn.Request{HTTPRequest: &http.Request{
@@ -64,7 +71,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthInvalidState,
@@ -76,7 +83,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthMissingPKCE,
@@ -88,7 +95,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -103,7 +110,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -119,7 +126,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -157,7 +164,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
allowInsecureTakeover: true,
addStateCookie: true,
stateCookieValue: "some-state",
@@ -211,12 +218,18 @@ func TestOAuth_Authenticate(t *testing.T) {
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
ExpectedUserInfo: tt.userInfo,
ExpectedToken: &oauth2.Token{},
ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
}, nil, nil)
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
ExpectedConnector: fakeConnector{
ExpectedUserInfo: tt.userInfo,
ExpectedToken: &oauth2.Token{},
ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc)
identity, err := c.Authenticate(context.Background(), tt.req)
assert.ErrorIs(t, err, tt.expectedErr)
@@ -256,21 +269,27 @@ func TestOAuth_RedirectURL(t *testing.T) {
tests := []testCase{
{
desc: "should generate redirect url and state",
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with hosted domain option if configured",
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"},
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true},
numCallOptions: 1,
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with pkce if configured",
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
numCallOptions: 1,
authCodeUrlCalled: true,
},
{
desc: "should return error if the client is not enabled",
oauthCfg: &social.OAuthInfo{Enabled: false},
authCodeUrlCalled: false,
expectedErr: errOAuthClientDisabled,
},
}
for _, tt := range tests {
@@ -279,13 +298,18 @@ func TestOAuth_RedirectURL(t *testing.T) {
authCodeUrlCalled = false
)
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
authCodeUrlCalled = true
require.Len(t, opts, tt.numCallOptions)
return ""
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
ExpectedConnector: mockConnector{
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
authCodeUrlCalled = true
require.Len(t, opts, tt.numCallOptions)
return ""
},
},
}, nil, nil)
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), nil, fakeSocialSvc)
redirect, err := c.RedirectURL(context.Background(), nil)
assert.ErrorIs(t, err, tt.expectedErr)
@@ -321,12 +345,17 @@ func TestOAuth_Logout(t *testing.T) {
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{},
},
{
desc: "should not return redirect url when client is not enabled",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{Enabled: false},
},
{
desc: "should return redirect url for globably configured redirect url",
cfg: &setting.Cfg{
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
expectedURL: "http://idp.com/logout",
expectedOK: true,
},
@@ -334,6 +363,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should return redirect url for client configured redirect url",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout",
},
expectedURL: "http://idp.com/logout",
@@ -345,6 +375,7 @@ func TestOAuth_Logout(t *testing.T) {
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp-2.com/logout",
},
expectedURL: "http://idp-2.com/logout",
@@ -354,6 +385,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should add id token hint if oicd logout is configured and token is valid",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
},
expectedURL: "http://idp.com/logout",
@@ -387,7 +419,10 @@ func TestOAuth_Logout(t *testing.T) {
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc)
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})