From 8cb351e54a7a051d155a1f847f881a4b65c3b5b2 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Tue, 19 Dec 2023 10:17:28 +0100 Subject: [PATCH] Authn: Handle logout logic in auth broker (#79635) * AuthN: Add new client extension interface that allows for custom logout logic * AuthN: Add tests for oauth client logout * Call authn.Logout Co-authored-by: Gabriel MABILLE --- pkg/api/login.go | 119 +++---------------- pkg/services/authn/authn.go | 11 ++ pkg/services/authn/authnimpl/service.go | 75 ++++++++++-- pkg/services/authn/authnimpl/service_test.go | 93 +++++++++++++++ pkg/services/authn/authntest/fake.go | 6 + pkg/services/authn/authntest/mock.go | 16 +++ pkg/services/authn/clients/oauth.go | 93 +++++++++++++-- pkg/services/authn/clients/oauth_test.go | 109 ++++++++++++++++- 8 files changed, 395 insertions(+), 127 deletions(-) diff --git a/pkg/api/login.go b/pkg/api/login.go index e9e88cf3841..9f6376a5469 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -29,8 +29,6 @@ import ( const ( viewIndex = "index" loginErrorCookieName = "login_error" - // #nosec G101 - this is not a hardcoded secret - postLogoutRedirectParam = "post_logout_redirect_uri" ) var setIndexViewData = (*HTTPServer).setIndexViewData @@ -243,70 +241,31 @@ func (hs *HTTPServer) loginUserWithUser(user *user.User, c *contextmodel.ReqCont } func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) { - userID, errID := identity.UserIdentifier(c.SignedInUser.GetNamespacedID()) - if errID != nil { - hs.log.Error("failed to retrieve user ID", "error", errID) - } - - oauthProviderSignoutRedirectUrl := "" - getAuthQuery := loginservice.GetAuthInfoQuery{UserId: userID} - authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery) - if err == nil { - // If SAML is enabled and this is a SAML user use saml logout - if hs.samlSingleLogoutEnabled() { - if authInfo.AuthModule == loginservice.SAMLAuthModule { - c.Redirect(hs.Cfg.AppSubURL + "/logout/saml") - return - } + // FIXME: restructure saml client to implement authn.LogoutClient + if hs.samlSingleLogoutEnabled() { + id, err := identity.UserIdentifier(c.SignedInUser.GetNamespacedID()) + if err != nil { + hs.log.Error("failed to retrieve user ID", "error", err) } - oauthProvider := hs.SocialService.GetOAuthInfoProvider(strings.TrimPrefix(authInfo.AuthModule, "oauth_")) - if oauthProvider != nil { - oauthProviderSignoutRedirectUrl = oauthProvider.SignoutRedirectUrl + + authInfo, _ := hs.authInfoService.GetAuthInfo(c.Req.Context(), &loginservice.GetAuthInfoQuery{UserId: id}) + if authInfo != nil && authInfo.AuthModule == loginservice.SAMLAuthModule { + c.Redirect(hs.Cfg.AppSubURL + "/logout/saml") + return } } - hs.log.Debug("Logout Redirect url", "auth.SignoutRedirectUrl:", hs.Cfg.SignoutRedirectUrl) - hs.log.Debug("Logout Redirect url", "oauth provider redirect url:", oauthProviderSignoutRedirectUrl) - - signOutRedirectUrl := getSignOutRedirectUrl(hs.Cfg.SignoutRedirectUrl, oauthProviderSignoutRedirectUrl) - - hs.log.Debug("Logout Redirect url", "signOurRedirectUrl:", signOutRedirectUrl) - idTokenHint := "" - oidcLogout := isPostLogoutRedirectConfigured(signOutRedirectUrl) - - // Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one - if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists { - token := hs.oauthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser) - if oidcLogout { - if token.Valid() { - idTokenHint = token.Extra("id_token").(string) - } else { - hs.log.Warn("Token is not valid") - } - } - - if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil { - hs.log.Warn("failed to invalidate oauth tokens for user", "userId", userID, "error", err) - } - } - - err = hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false) - if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) { - hs.log.Error("failed to revoke auth token", "error", err) - } - + redirect, err := hs.authnService.Logout(c.Req.Context(), c.SignedInUser, c.UserToken) authn.DeleteSessionCookie(c.Resp, hs.Cfg) - rdUrl := signOutRedirectUrl - if rdUrl != "" { - if oidcLogout { - rdUrl = getPostRedirectUrl(signOutRedirectUrl, idTokenHint) - } - c.Redirect(rdUrl) - } else { - hs.log.Info("Successful Logout", "User", c.SignedInUser.GetEmail()) + if err != nil { + hs.log.Error("Failed perform proper logout", "error", err) c.Redirect(hs.Cfg.AppSubURL + "/login") } + + _, id := c.SignedInUser.GetNamespacedID() + hs.log.Info("Successful Logout", "userID", id) + c.Redirect(redirect.URL) } func (hs *HTTPServer) tryGetEncryptedCookie(ctx *contextmodel.ReqContext, cookieName string) (string, bool) { @@ -420,47 +379,3 @@ func getFirstPublicErrorMessage(err *errutil.Error) string { return errPublic.Message } - -func isPostLogoutRedirectConfigured(redirectUrl string) bool { - if redirectUrl == "" { - return false - } - - u, err := url.Parse(redirectUrl) - if err != nil { - return false - } - - q := u.Query() - _, ok := q[postLogoutRedirectParam] - return ok -} - -func getPostRedirectUrl(rdUrl string, tokenHint string) string { - if tokenHint == "" { - return rdUrl - } - if rdUrl == "" { - return rdUrl - } - - u, err := url.Parse(rdUrl) - if err != nil { - return rdUrl - } - - q := u.Query() - q.Set("id_token_hint", tokenHint) - u.RawQuery = q.Encode() - - return u.String() -} - -func getSignOutRedirectUrl(gRdUrl string, oauthProviderUrl string) string { - if oauthProviderUrl != "" { - return oauthProviderUrl - } else if gRdUrl != "" { - return gRdUrl - } - return "" -} diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index 2fd46dc8114..12ac3b826c1 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -11,6 +11,7 @@ import ( "github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models/usertoken" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" @@ -74,6 +75,8 @@ type Service interface { RegisterPostLoginHook(hook PostLoginHookFn, priority uint) // RedirectURL will generate url that we can use to initiate auth flow for supported clients. RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error) + // Logout revokes session token and does additional clean up if client used to authenticate supports it + Logout(ctx context.Context, user identity.Requester, sessionToken *usertoken.UserToken) (*Redirect, error) // RegisterClient will register a new authn.Client that can be used for authentication RegisterClient(c Client) } @@ -115,6 +118,14 @@ type RedirectClient interface { RedirectURL(ctx context.Context, r *Request) (*Redirect, error) } +// LogoutCLient is an optional interface that auth client can implement. +// Clients that implements this interface can implement additional logic +// that should happen during logout and supports client specific redirect URL. +type LogoutClient interface { + Client + Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*Redirect, bool) +} + type PasswordClient interface { AuthenticatePassword(ctx context.Context, r *Request, username, password string) (*Identity, error) } diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index ac191c0e299..adf0ad0e980 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "strconv" + "strings" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/attribute" @@ -19,6 +20,7 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/apikey" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/authnimpl/sync" "github.com/grafana/grafana/pkg/services/authn/clients" @@ -73,15 +75,16 @@ func ProvideService( signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server, ) *Service { s := &Service{ - log: log.New("authn.service"), - cfg: cfg, - clients: make(map[string]authn.Client), - clientQueue: newQueue[authn.ContextAwareClient](), - tracer: tracer, - metrics: newMetrics(registerer), - sessionService: sessionService, - postAuthHooks: newQueue[authn.PostAuthHookFn](), - postLoginHooks: newQueue[authn.PostLoginHookFn](), + log: log.New("authn.service"), + cfg: cfg, + clients: make(map[string]authn.Client), + clientQueue: newQueue[authn.ContextAwareClient](), + tracer: tracer, + metrics: newMetrics(registerer), + authInfoService: authInfoService, + sessionService: sessionService, + postAuthHooks: newQueue[authn.PostAuthHookFn](), + postLoginHooks: newQueue[authn.PostLoginHookFn](), } usageStats.RegisterMetricsFunc(s.getUsageStats) @@ -146,7 +149,7 @@ func ProvideService( if errConnector != nil || errHTTPClient != nil { s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient)) } else { - s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient)) + s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService)) } } } @@ -175,7 +178,8 @@ type Service struct { tracer tracing.Tracer metrics *metrics - sessionService auth.UserTokenService + authInfoService login.AuthInfoService + sessionService auth.UserTokenService // postAuthHooks are called after a successful authentication. They can modify the identity. postAuthHooks *queue[authn.PostAuthHookFn] @@ -335,6 +339,55 @@ func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Reque return redirectClient.RedirectURL(ctx, r) } +func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionToken *auth.UserToken) (*authn.Redirect, error) { + ctx, span := s.tracer.Start(ctx, "authn.Logout") + defer span.End() + + redirect := &authn.Redirect{URL: s.cfg.AppSubURL + "/login"} + + namespace, id := user.GetNamespacedID() + if namespace != authn.NamespaceUser { + return redirect, nil + } + + userID, err := identity.IntIdentifier(namespace, id) + if err != nil { + s.log.FromContext(ctx).Debug("Invalid user id", "id", userID, "err", err) + return redirect, nil + } + + info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID}) + if info != nil { + client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_")) + + c, ok := s.clients[client] + if !ok { + s.log.FromContext(ctx).Debug("No client configured for auth module", "client", client) + goto Default + } + + logoutClient, ok := c.(authn.LogoutClient) + if !ok { + s.log.FromContext(ctx).Debug("Client do not support specialized logout logic", "client", client) + goto Default + } + + clientRedirect, ok := logoutClient.Logout(ctx, user, info) + if !ok { + goto Default + } + + redirect = clientRedirect + } + +Default: + if err = s.sessionService.RevokeToken(ctx, sessionToken, false); err != nil { + return nil, err + } + + return redirect, nil +} + func (s *Service) RegisterClient(c authn.Client) { s.clients[c.Name()] = c if cac, ok := c.(authn.ContextAwareClient); ok { diff --git a/pkg/services/authn/authnimpl/service_test.go b/pkg/services/authn/authnimpl/service_test.go index 8e253d23dd9..b053569d5ee 100644 --- a/pkg/services/authn/authnimpl/service_test.go +++ b/pkg/services/authn/authnimpl/service_test.go @@ -13,10 +13,14 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/authntest" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/login/authinfotest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -299,6 +303,95 @@ func TestService_RedirectURL(t *testing.T) { } } +func TestService_Logout(t *testing.T) { + type TestCase struct { + desc string + + identity *authn.Identity + sessionToken *usertoken.UserToken + info *login.UserAuth + + client authn.Client + + expectedErr error + expectedTokenRevoked bool + expectedRedirect *authn.Redirect + } + + tests := []TestCase{ + { + desc: "should redirect to default redirect url when identity is not a user", + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceServiceAccount, 1)}, + expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, + }, + { + desc: "should redirect to default redirect url when no external provider was used to authenticate", + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, + expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, + expectedTokenRevoked: true, + }, + { + desc: "should redirect to default redirect url when client is not found", + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, + info: &login.UserAuth{AuthModule: "notFound"}, + expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, + expectedTokenRevoked: true, + }, + { + desc: "should redirect to default redirect url when client do not implement logout extension", + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, + info: &login.UserAuth{AuthModule: "azuread"}, + expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, + client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"}, + expectedTokenRevoked: true, + }, + { + desc: "should redirect to client specific url", + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, + info: &login.UserAuth{AuthModule: "azuread"}, + expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"}, + client: &authntest.MockClient{ + NameFunc: func() string { return "auth.client.azuread" }, + LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) { + return &authn.Redirect{URL: "http://idp.com/logout"}, true + }, + }, + expectedTokenRevoked: true, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var tokenRevoked bool + + s := setupTests(t, func(svc *Service) { + if tt.client != nil { + svc.RegisterClient(tt.client) + } + svc.cfg.AppSubURL = "http://localhost:3000" + svc.authInfoService = &authinfotest.FakeService{ + ExpectedUserAuth: tt.info, + } + + svc.sessionService = &authtest.FakeUserAuthTokenService{ + RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error { + tokenRevoked = true + assert.EqualValues(t, tt.sessionToken, sessionToken) + assert.False(t, soft) + return nil + }, + } + }) + + redirect, err := s.Logout(context.Background(), tt.identity, tt.sessionToken) + + assert.ErrorIs(t, err, tt.expectedErr) + assert.EqualValues(t, tt.expectedRedirect, redirect) + assert.Equal(t, tt.expectedTokenRevoked, tokenRevoked) + }) + } +} + func mustParseURL(s string) *url.URL { u, err := url.Parse(s) if err != nil { diff --git a/pkg/services/authn/authntest/fake.go b/pkg/services/authn/authntest/fake.go index 3c7af4060f1..81a93636585 100644 --- a/pkg/services/authn/authntest/fake.go +++ b/pkg/services/authn/authntest/fake.go @@ -3,6 +3,8 @@ package authntest import ( "context" + "github.com/grafana/grafana/pkg/models/usertoken" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" ) @@ -66,6 +68,10 @@ func (f *FakeService) RedirectURL(ctx context.Context, client string, r *authn.R return f.ExpectedRedirect, f.ExpectedErr } +func (*FakeService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) { + panic("unimplemented") +} + func (f *FakeService) RegisterClient(c authn.Client) {} func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error { diff --git a/pkg/services/authn/authntest/mock.go b/pkg/services/authn/authntest/mock.go index 0868fad5aaa..2a46f97a5c8 100644 --- a/pkg/services/authn/authntest/mock.go +++ b/pkg/services/authn/authntest/mock.go @@ -3,7 +3,10 @@ package authntest import ( "context" + "github.com/grafana/grafana/pkg/models/usertoken" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" ) var _ authn.Service = new(MockService) @@ -40,6 +43,10 @@ func (m *MockService) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority panic("unimplemented") } +func (*MockService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) { + panic("unimplemented") +} + func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error { if m.SyncIdentityFunc != nil { return m.SyncIdentityFunc(ctx, identity) @@ -48,6 +55,7 @@ func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity } var _ authn.HookClient = new(MockClient) +var _ authn.LogoutClient = new(MockClient) var _ authn.ContextAwareClient = new(MockClient) type MockClient struct { @@ -56,6 +64,7 @@ type MockClient struct { TestFunc func(ctx context.Context, r *authn.Request) bool PriorityFunc func() uint HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error + LogoutFunc func(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) } func (m MockClient) Name() string { @@ -93,6 +102,13 @@ func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn return nil } +func (m *MockClient) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) { + if m.LogoutFunc != nil { + return m.LogoutFunc(ctx, user, info) + } + return nil, false +} + var _ authn.ProxyClient = new(MockProxyClient) type MockProxyClient struct { diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index 55fcca99c63..c27a63b1d03 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "strings" "golang.org/x/oauth2" @@ -16,8 +17,10 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social/connectors" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util/errutil" @@ -30,9 +33,10 @@ const ( codeChallengeMethodParamName = "code_challenge_method" codeChallengeMethod = "S256" - oauthStateQueryName = "state" - oauthStateCookieName = "oauth_state" - oauthPKCECookieName = "oauth_code_verifier" + oauthStateQueryName = "state" + oauthStateCookieName = "oauth_state" + oauthPKCECookieName = "oauth_code_verifier" + oauthPostLogoutRedirectParam = "post_logout_redirect_uri" ) var ( @@ -54,26 +58,28 @@ func fromSocialErr(err *connectors.SocialError) error { return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err) } +var _ authn.LogoutClient = new(OAuth) var _ authn.RedirectClient = new(OAuth) func ProvideOAuth( name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo, - connector social.SocialConnector, httpClient *http.Client, + connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService, ) *OAuth { return &OAuth{ name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")), - log.New(name), cfg, oauthCfg, connector, httpClient, + log.New(name), cfg, oauthCfg, connector, httpClient, oauthService, } } type OAuth struct { - name string - moduleName string - log log.Logger - cfg *setting.Cfg - oauthCfg *social.OAuthInfo - connector social.SocialConnector - httpClient *http.Client + name string + moduleName string + log log.Logger + cfg *setting.Cfg + oauthCfg *social.OAuthInfo + connector social.SocialConnector + httpClient *http.Client + oauthService oauthtoken.OAuthTokenService } func (c *OAuth) Name() string { @@ -204,6 +210,29 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir }, nil } +func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) { + token := c.oauthService.GetCurrentOAuthToken(ctx, user) + + if err := c.oauthService.InvalidateOAuthTokens(ctx, info); err != nil { + namespace, id := user.GetNamespacedID() + c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err) + } + + redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg) + if redirctURL == "" { + c.log.FromContext(ctx).Debug("No signout redirect url configured") + return nil, false + } + + if isOICDLogout(redirctURL) && token != nil && token.Valid() { + if idToken, ok := token.Extra("id_token").(string); ok { + redirctURL = withIDTokenHint(redirctURL, idToken) + } + } + + return &authn.Redirect{URL: redirctURL}, true +} + // genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest. func genPKCECode() (string, string, error) { // IETF RFC 7636 specifies that the code verifier should be 43-128 @@ -243,3 +272,43 @@ func hashOAuthState(state, secret, seed string) string { hashBytes := sha256.Sum256([]byte(state + secret + seed)) return hex.EncodeToString(hashBytes[:]) } + +func getOAuthSignoutRedirectURL(cfg *setting.Cfg, oauthCfg *social.OAuthInfo) string { + if oauthCfg.SignoutRedirectUrl != "" { + return oauthCfg.SignoutRedirectUrl + } + + return cfg.SignoutRedirectUrl +} + +func withIDTokenHint(redirectURL string, idToken string) string { + if idToken == "" { + return redirectURL + } + + u, err := url.Parse(redirectURL) + if err != nil { + return redirectURL + } + + q := u.Query() + q.Set("id_token_hint", idToken) + u.RawQuery = q.Encode() + + return u.String() +} + +func isOICDLogout(redirectUrl string) bool { + if redirectUrl == "" { + return false + } + + u, err := url.Parse(redirectUrl) + if err != nil { + return false + } + + q := u.Query() + _, ok := q[oauthPostLogoutRedirectParam] + return ok +} diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index 30da79459c0..893bc6934f9 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -4,7 +4,9 @@ import ( "context" "net/http" "net/url" + "strings" "testing" + "time" "golang.org/x/oauth2" @@ -12,8 +14,10 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/setting" ) @@ -212,7 +216,7 @@ func TestOAuth_Authenticate(t *testing.T) { ExpectedToken: &oauth2.Token{}, ExpectedIsSignupAllowed: true, ExpectedIsEmailAllowed: tt.isEmailAllowed, - }, nil) + }, nil, nil) identity, err := c.Authenticate(context.Background(), tt.req) assert.ErrorIs(t, err, tt.expectedErr) @@ -281,7 +285,7 @@ func TestOAuth_RedirectURL(t *testing.T) { require.Len(t, opts, tt.numCallOptions) return "" }, - }, nil) + }, nil, nil) redirect, err := c.RedirectURL(context.Background(), nil) assert.ErrorIs(t, err, tt.expectedErr) @@ -299,6 +303,107 @@ func TestOAuth_RedirectURL(t *testing.T) { } } +func TestOAuth_Logout(t *testing.T) { + type testCase struct { + desc string + cfg *setting.Cfg + oauthCfg *social.OAuthInfo + + expectedOK bool + expectedURL string + expectedIDTokenHint string + expectedPostLogoutURI string + } + + tests := []testCase{ + { + desc: "should not return redirect url if not configured for client or globably", + cfg: &setting.Cfg{}, + oauthCfg: &social.OAuthInfo{}, + }, + { + desc: "should return redirect url for globably configured redirect url", + cfg: &setting.Cfg{ + SignoutRedirectUrl: "http://idp.com/logout", + }, + oauthCfg: &social.OAuthInfo{}, + expectedURL: "http://idp.com/logout", + expectedOK: true, + }, + { + desc: "should return redirect url for client configured redirect url", + cfg: &setting.Cfg{}, + oauthCfg: &social.OAuthInfo{ + SignoutRedirectUrl: "http://idp.com/logout", + }, + expectedURL: "http://idp.com/logout", + expectedOK: true, + }, + { + desc: "client specific url should take precedence", + cfg: &setting.Cfg{ + SignoutRedirectUrl: "http://idp.com/logout", + }, + oauthCfg: &social.OAuthInfo{ + SignoutRedirectUrl: "http://idp-2.com/logout", + }, + expectedURL: "http://idp-2.com/logout", + expectedOK: true, + }, + { + desc: "should add id token hint if oicd logout is configured and token is valid", + cfg: &setting.Cfg{}, + oauthCfg: &social.OAuthInfo{ + SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin", + }, + expectedURL: "http://idp.com/logout", + expectedIDTokenHint: "id_token_hint=some.id.token", + expectedPostLogoutURI: "http%3A%3A%2F%2Ftest.com%2Flogin", + expectedOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var ( + getTokenCalled bool + invalidateTokenCalled bool + ) + + mockService := &oauthtokentest.MockOauthTokenService{ + GetCurrentOauthTokenFunc: func(_ context.Context, _ identity.Requester) *oauth2.Token { + getTokenCalled = true + token := &oauth2.Token{ + AccessToken: "some.access.token", + Expiry: time.Now().Add(10 * time.Minute), + } + return token.WithExtra(map[string]any{ + "id_token": "some.id.token", + }) + }, + InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error { + invalidateTokenCalled = true + return nil + }, + } + + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService) + + redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{}) + + assert.Equal(t, tt.expectedOK, ok) + if tt.expectedOK { + assert.True(t, strings.HasPrefix(redirect.URL, tt.expectedURL)) + assert.Contains(t, redirect.URL, tt.expectedIDTokenHint) + assert.Contains(t, redirect.URL, tt.expectedPostLogoutURI) + } + + assert.True(t, getTokenCalled) + assert.True(t, invalidateTokenCalled) + }) + } +} + type mockConnector struct { AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string social.SocialConnector