mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Authn: Handle logout logic in auth broker (#79635)
* AuthN: Add new client extension interface that allows for custom logout logic * AuthN: Add tests for oauth client logout * Call authn.Logout Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
This commit is contained in:
parent
eb490193b9
commit
8cb351e54a
119
pkg/api/login.go
119
pkg/api/login.go
@ -29,8 +29,6 @@ import (
|
||||
const (
|
||||
viewIndex = "index"
|
||||
loginErrorCookieName = "login_error"
|
||||
// #nosec G101 - this is not a hardcoded secret
|
||||
postLogoutRedirectParam = "post_logout_redirect_uri"
|
||||
)
|
||||
|
||||
var setIndexViewData = (*HTTPServer).setIndexViewData
|
||||
@ -243,70 +241,31 @@ func (hs *HTTPServer) loginUserWithUser(user *user.User, c *contextmodel.ReqCont
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
|
||||
userID, errID := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
|
||||
if errID != nil {
|
||||
hs.log.Error("failed to retrieve user ID", "error", errID)
|
||||
}
|
||||
|
||||
oauthProviderSignoutRedirectUrl := ""
|
||||
getAuthQuery := loginservice.GetAuthInfoQuery{UserId: userID}
|
||||
authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery)
|
||||
if err == nil {
|
||||
// If SAML is enabled and this is a SAML user use saml logout
|
||||
if hs.samlSingleLogoutEnabled() {
|
||||
if authInfo.AuthModule == loginservice.SAMLAuthModule {
|
||||
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
|
||||
return
|
||||
}
|
||||
// FIXME: restructure saml client to implement authn.LogoutClient
|
||||
if hs.samlSingleLogoutEnabled() {
|
||||
id, err := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
|
||||
if err != nil {
|
||||
hs.log.Error("failed to retrieve user ID", "error", err)
|
||||
}
|
||||
oauthProvider := hs.SocialService.GetOAuthInfoProvider(strings.TrimPrefix(authInfo.AuthModule, "oauth_"))
|
||||
if oauthProvider != nil {
|
||||
oauthProviderSignoutRedirectUrl = oauthProvider.SignoutRedirectUrl
|
||||
|
||||
authInfo, _ := hs.authInfoService.GetAuthInfo(c.Req.Context(), &loginservice.GetAuthInfoQuery{UserId: id})
|
||||
if authInfo != nil && authInfo.AuthModule == loginservice.SAMLAuthModule {
|
||||
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hs.log.Debug("Logout Redirect url", "auth.SignoutRedirectUrl:", hs.Cfg.SignoutRedirectUrl)
|
||||
hs.log.Debug("Logout Redirect url", "oauth provider redirect url:", oauthProviderSignoutRedirectUrl)
|
||||
|
||||
signOutRedirectUrl := getSignOutRedirectUrl(hs.Cfg.SignoutRedirectUrl, oauthProviderSignoutRedirectUrl)
|
||||
|
||||
hs.log.Debug("Logout Redirect url", "signOurRedirectUrl:", signOutRedirectUrl)
|
||||
idTokenHint := ""
|
||||
oidcLogout := isPostLogoutRedirectConfigured(signOutRedirectUrl)
|
||||
|
||||
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
|
||||
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
|
||||
token := hs.oauthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser)
|
||||
if oidcLogout {
|
||||
if token.Valid() {
|
||||
idTokenHint = token.Extra("id_token").(string)
|
||||
} else {
|
||||
hs.log.Warn("Token is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
|
||||
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", userID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
|
||||
if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) {
|
||||
hs.log.Error("failed to revoke auth token", "error", err)
|
||||
}
|
||||
|
||||
redirect, err := hs.authnService.Logout(c.Req.Context(), c.SignedInUser, c.UserToken)
|
||||
authn.DeleteSessionCookie(c.Resp, hs.Cfg)
|
||||
|
||||
rdUrl := signOutRedirectUrl
|
||||
if rdUrl != "" {
|
||||
if oidcLogout {
|
||||
rdUrl = getPostRedirectUrl(signOutRedirectUrl, idTokenHint)
|
||||
}
|
||||
c.Redirect(rdUrl)
|
||||
} else {
|
||||
hs.log.Info("Successful Logout", "User", c.SignedInUser.GetEmail())
|
||||
if err != nil {
|
||||
hs.log.Error("Failed perform proper logout", "error", err)
|
||||
c.Redirect(hs.Cfg.AppSubURL + "/login")
|
||||
}
|
||||
|
||||
_, id := c.SignedInUser.GetNamespacedID()
|
||||
hs.log.Info("Successful Logout", "userID", id)
|
||||
c.Redirect(redirect.URL)
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) tryGetEncryptedCookie(ctx *contextmodel.ReqContext, cookieName string) (string, bool) {
|
||||
@ -420,47 +379,3 @@ func getFirstPublicErrorMessage(err *errutil.Error) string {
|
||||
|
||||
return errPublic.Message
|
||||
}
|
||||
|
||||
func isPostLogoutRedirectConfigured(redirectUrl string) bool {
|
||||
if redirectUrl == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirectUrl)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
_, ok := q[postLogoutRedirectParam]
|
||||
return ok
|
||||
}
|
||||
|
||||
func getPostRedirectUrl(rdUrl string, tokenHint string) string {
|
||||
if tokenHint == "" {
|
||||
return rdUrl
|
||||
}
|
||||
if rdUrl == "" {
|
||||
return rdUrl
|
||||
}
|
||||
|
||||
u, err := url.Parse(rdUrl)
|
||||
if err != nil {
|
||||
return rdUrl
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", tokenHint)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func getSignOutRedirectUrl(gRdUrl string, oauthProviderUrl string) string {
|
||||
if oauthProviderUrl != "" {
|
||||
return oauthProviderUrl
|
||||
} else if gRdUrl != "" {
|
||||
return gRdUrl
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/api/response"
|
||||
"github.com/grafana/grafana/pkg/middleware/cookies"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
@ -74,6 +75,8 @@ type Service interface {
|
||||
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
|
||||
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
|
||||
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
|
||||
// Logout revokes session token and does additional clean up if client used to authenticate supports it
|
||||
Logout(ctx context.Context, user identity.Requester, sessionToken *usertoken.UserToken) (*Redirect, error)
|
||||
// RegisterClient will register a new authn.Client that can be used for authentication
|
||||
RegisterClient(c Client)
|
||||
}
|
||||
@ -115,6 +118,14 @@ type RedirectClient interface {
|
||||
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
|
||||
}
|
||||
|
||||
// LogoutCLient is an optional interface that auth client can implement.
|
||||
// Clients that implements this interface can implement additional logic
|
||||
// that should happen during logout and supports client specific redirect URL.
|
||||
type LogoutClient interface {
|
||||
Client
|
||||
Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*Redirect, bool)
|
||||
}
|
||||
|
||||
type PasswordClient interface {
|
||||
AuthenticatePassword(ctx context.Context, r *Request, username, password string) (*Identity, error)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
|
||||
"github.com/grafana/grafana/pkg/services/authn/clients"
|
||||
@ -73,15 +75,16 @@ func ProvideService(
|
||||
signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server,
|
||||
) *Service {
|
||||
s := &Service{
|
||||
log: log.New("authn.service"),
|
||||
cfg: cfg,
|
||||
clients: make(map[string]authn.Client),
|
||||
clientQueue: newQueue[authn.ContextAwareClient](),
|
||||
tracer: tracer,
|
||||
metrics: newMetrics(registerer),
|
||||
sessionService: sessionService,
|
||||
postAuthHooks: newQueue[authn.PostAuthHookFn](),
|
||||
postLoginHooks: newQueue[authn.PostLoginHookFn](),
|
||||
log: log.New("authn.service"),
|
||||
cfg: cfg,
|
||||
clients: make(map[string]authn.Client),
|
||||
clientQueue: newQueue[authn.ContextAwareClient](),
|
||||
tracer: tracer,
|
||||
metrics: newMetrics(registerer),
|
||||
authInfoService: authInfoService,
|
||||
sessionService: sessionService,
|
||||
postAuthHooks: newQueue[authn.PostAuthHookFn](),
|
||||
postLoginHooks: newQueue[authn.PostLoginHookFn](),
|
||||
}
|
||||
|
||||
usageStats.RegisterMetricsFunc(s.getUsageStats)
|
||||
@ -146,7 +149,7 @@ func ProvideService(
|
||||
if errConnector != nil || errHTTPClient != nil {
|
||||
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
|
||||
} else {
|
||||
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient))
|
||||
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -175,7 +178,8 @@ type Service struct {
|
||||
tracer tracing.Tracer
|
||||
metrics *metrics
|
||||
|
||||
sessionService auth.UserTokenService
|
||||
authInfoService login.AuthInfoService
|
||||
sessionService auth.UserTokenService
|
||||
|
||||
// postAuthHooks are called after a successful authentication. They can modify the identity.
|
||||
postAuthHooks *queue[authn.PostAuthHookFn]
|
||||
@ -335,6 +339,55 @@ func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Reque
|
||||
return redirectClient.RedirectURL(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionToken *auth.UserToken) (*authn.Redirect, error) {
|
||||
ctx, span := s.tracer.Start(ctx, "authn.Logout")
|
||||
defer span.End()
|
||||
|
||||
redirect := &authn.Redirect{URL: s.cfg.AppSubURL + "/login"}
|
||||
|
||||
namespace, id := user.GetNamespacedID()
|
||||
if namespace != authn.NamespaceUser {
|
||||
return redirect, nil
|
||||
}
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Debug("Invalid user id", "id", userID, "err", err)
|
||||
return redirect, nil
|
||||
}
|
||||
|
||||
info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID})
|
||||
if info != nil {
|
||||
client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_"))
|
||||
|
||||
c, ok := s.clients[client]
|
||||
if !ok {
|
||||
s.log.FromContext(ctx).Debug("No client configured for auth module", "client", client)
|
||||
goto Default
|
||||
}
|
||||
|
||||
logoutClient, ok := c.(authn.LogoutClient)
|
||||
if !ok {
|
||||
s.log.FromContext(ctx).Debug("Client do not support specialized logout logic", "client", client)
|
||||
goto Default
|
||||
}
|
||||
|
||||
clientRedirect, ok := logoutClient.Logout(ctx, user, info)
|
||||
if !ok {
|
||||
goto Default
|
||||
}
|
||||
|
||||
redirect = clientRedirect
|
||||
}
|
||||
|
||||
Default:
|
||||
if err = s.sessionService.RevokeToken(ctx, sessionToken, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return redirect, nil
|
||||
}
|
||||
|
||||
func (s *Service) RegisterClient(c authn.Client) {
|
||||
s.clients[c.Name()] = c
|
||||
if cac, ok := c.(authn.ContextAwareClient); ok {
|
||||
|
@ -13,10 +13,14 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfotest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
@ -299,6 +303,95 @@ func TestService_RedirectURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Logout(t *testing.T) {
|
||||
type TestCase struct {
|
||||
desc string
|
||||
|
||||
identity *authn.Identity
|
||||
sessionToken *usertoken.UserToken
|
||||
info *login.UserAuth
|
||||
|
||||
client authn.Client
|
||||
|
||||
expectedErr error
|
||||
expectedTokenRevoked bool
|
||||
expectedRedirect *authn.Redirect
|
||||
}
|
||||
|
||||
tests := []TestCase{
|
||||
{
|
||||
desc: "should redirect to default redirect url when identity is not a user",
|
||||
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceServiceAccount, 1)},
|
||||
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||
},
|
||||
{
|
||||
desc: "should redirect to default redirect url when no external provider was used to authenticate",
|
||||
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||
expectedTokenRevoked: true,
|
||||
},
|
||||
{
|
||||
desc: "should redirect to default redirect url when client is not found",
|
||||
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||
info: &login.UserAuth{AuthModule: "notFound"},
|
||||
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||
expectedTokenRevoked: true,
|
||||
},
|
||||
{
|
||||
desc: "should redirect to default redirect url when client do not implement logout extension",
|
||||
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||
info: &login.UserAuth{AuthModule: "azuread"},
|
||||
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||
client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"},
|
||||
expectedTokenRevoked: true,
|
||||
},
|
||||
{
|
||||
desc: "should redirect to client specific url",
|
||||
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||
info: &login.UserAuth{AuthModule: "azuread"},
|
||||
expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"},
|
||||
client: &authntest.MockClient{
|
||||
NameFunc: func() string { return "auth.client.azuread" },
|
||||
LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) {
|
||||
return &authn.Redirect{URL: "http://idp.com/logout"}, true
|
||||
},
|
||||
},
|
||||
expectedTokenRevoked: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var tokenRevoked bool
|
||||
|
||||
s := setupTests(t, func(svc *Service) {
|
||||
if tt.client != nil {
|
||||
svc.RegisterClient(tt.client)
|
||||
}
|
||||
svc.cfg.AppSubURL = "http://localhost:3000"
|
||||
svc.authInfoService = &authinfotest.FakeService{
|
||||
ExpectedUserAuth: tt.info,
|
||||
}
|
||||
|
||||
svc.sessionService = &authtest.FakeUserAuthTokenService{
|
||||
RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error {
|
||||
tokenRevoked = true
|
||||
assert.EqualValues(t, tt.sessionToken, sessionToken)
|
||||
assert.False(t, soft)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
redirect, err := s.Logout(context.Background(), tt.identity, tt.sessionToken)
|
||||
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.EqualValues(t, tt.expectedRedirect, redirect)
|
||||
assert.Equal(t, tt.expectedTokenRevoked, tokenRevoked)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
|
@ -3,6 +3,8 @@ package authntest
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
)
|
||||
|
||||
@ -66,6 +68,10 @@ func (f *FakeService) RedirectURL(ctx context.Context, client string, r *authn.R
|
||||
return f.ExpectedRedirect, f.ExpectedErr
|
||||
}
|
||||
|
||||
func (*FakeService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeService) RegisterClient(c authn.Client) {}
|
||||
|
||||
func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
||||
|
@ -3,7 +3,10 @@ package authntest
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
)
|
||||
|
||||
var _ authn.Service = new(MockService)
|
||||
@ -40,6 +43,10 @@ func (m *MockService) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*MockService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
||||
if m.SyncIdentityFunc != nil {
|
||||
return m.SyncIdentityFunc(ctx, identity)
|
||||
@ -48,6 +55,7 @@ func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity
|
||||
}
|
||||
|
||||
var _ authn.HookClient = new(MockClient)
|
||||
var _ authn.LogoutClient = new(MockClient)
|
||||
var _ authn.ContextAwareClient = new(MockClient)
|
||||
|
||||
type MockClient struct {
|
||||
@ -56,6 +64,7 @@ type MockClient struct {
|
||||
TestFunc func(ctx context.Context, r *authn.Request) bool
|
||||
PriorityFunc func() uint
|
||||
HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error
|
||||
LogoutFunc func(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool)
|
||||
}
|
||||
|
||||
func (m MockClient) Name() string {
|
||||
@ -93,6 +102,13 @@ func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockClient) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
|
||||
if m.LogoutFunc != nil {
|
||||
return m.LogoutFunc(ctx, user, info)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var _ authn.ProxyClient = new(MockProxyClient)
|
||||
|
||||
type MockProxyClient struct {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
@ -16,8 +17,10 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/social/connectors"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
@ -30,9 +33,10 @@ const (
|
||||
codeChallengeMethodParamName = "code_challenge_method"
|
||||
codeChallengeMethod = "S256"
|
||||
|
||||
oauthStateQueryName = "state"
|
||||
oauthStateCookieName = "oauth_state"
|
||||
oauthPKCECookieName = "oauth_code_verifier"
|
||||
oauthStateQueryName = "state"
|
||||
oauthStateCookieName = "oauth_state"
|
||||
oauthPKCECookieName = "oauth_code_verifier"
|
||||
oauthPostLogoutRedirectParam = "post_logout_redirect_uri"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -54,26 +58,28 @@ func fromSocialErr(err *connectors.SocialError) error {
|
||||
return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err)
|
||||
}
|
||||
|
||||
var _ authn.LogoutClient = new(OAuth)
|
||||
var _ authn.RedirectClient = new(OAuth)
|
||||
|
||||
func ProvideOAuth(
|
||||
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
||||
connector social.SocialConnector, httpClient *http.Client,
|
||||
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
|
||||
) *OAuth {
|
||||
return &OAuth{
|
||||
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
||||
log.New(name), cfg, oauthCfg, connector, httpClient,
|
||||
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
type OAuth struct {
|
||||
name string
|
||||
moduleName string
|
||||
log log.Logger
|
||||
cfg *setting.Cfg
|
||||
oauthCfg *social.OAuthInfo
|
||||
connector social.SocialConnector
|
||||
httpClient *http.Client
|
||||
name string
|
||||
moduleName string
|
||||
log log.Logger
|
||||
cfg *setting.Cfg
|
||||
oauthCfg *social.OAuthInfo
|
||||
connector social.SocialConnector
|
||||
httpClient *http.Client
|
||||
oauthService oauthtoken.OAuthTokenService
|
||||
}
|
||||
|
||||
func (c *OAuth) Name() string {
|
||||
@ -204,6 +210,29 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
|
||||
token := c.oauthService.GetCurrentOAuthToken(ctx, user)
|
||||
|
||||
if err := c.oauthService.InvalidateOAuthTokens(ctx, info); err != nil {
|
||||
namespace, id := user.GetNamespacedID()
|
||||
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
|
||||
}
|
||||
|
||||
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
|
||||
if redirctURL == "" {
|
||||
c.log.FromContext(ctx).Debug("No signout redirect url configured")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
|
||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||
redirctURL = withIDTokenHint(redirctURL, idToken)
|
||||
}
|
||||
}
|
||||
|
||||
return &authn.Redirect{URL: redirctURL}, true
|
||||
}
|
||||
|
||||
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
|
||||
func genPKCECode() (string, string, error) {
|
||||
// IETF RFC 7636 specifies that the code verifier should be 43-128
|
||||
@ -243,3 +272,43 @@ func hashOAuthState(state, secret, seed string) string {
|
||||
hashBytes := sha256.Sum256([]byte(state + secret + seed))
|
||||
return hex.EncodeToString(hashBytes[:])
|
||||
}
|
||||
|
||||
func getOAuthSignoutRedirectURL(cfg *setting.Cfg, oauthCfg *social.OAuthInfo) string {
|
||||
if oauthCfg.SignoutRedirectUrl != "" {
|
||||
return oauthCfg.SignoutRedirectUrl
|
||||
}
|
||||
|
||||
return cfg.SignoutRedirectUrl
|
||||
}
|
||||
|
||||
func withIDTokenHint(redirectURL string, idToken string) string {
|
||||
if idToken == "" {
|
||||
return redirectURL
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
return redirectURL
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func isOICDLogout(redirectUrl string) bool {
|
||||
if redirectUrl == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirectUrl)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
_, ok := q[oauthPostLogoutRedirectParam]
|
||||
return ok
|
||||
}
|
||||
|
@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
@ -12,8 +14,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
@ -212,7 +216,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
||||
ExpectedToken: &oauth2.Token{},
|
||||
ExpectedIsSignupAllowed: true,
|
||||
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
identity, err := c.Authenticate(context.Background(), tt.req)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
|
||||
@ -281,7 +285,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
||||
require.Len(t, opts, tt.numCallOptions)
|
||||
return ""
|
||||
},
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
redirect, err := c.RedirectURL(context.Background(), nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
@ -299,6 +303,107 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth_Logout(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
cfg *setting.Cfg
|
||||
oauthCfg *social.OAuthInfo
|
||||
|
||||
expectedOK bool
|
||||
expectedURL string
|
||||
expectedIDTokenHint string
|
||||
expectedPostLogoutURI string
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should not return redirect url if not configured for client or globably",
|
||||
cfg: &setting.Cfg{},
|
||||
oauthCfg: &social.OAuthInfo{},
|
||||
},
|
||||
{
|
||||
desc: "should return redirect url for globably configured redirect url",
|
||||
cfg: &setting.Cfg{
|
||||
SignoutRedirectUrl: "http://idp.com/logout",
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{},
|
||||
expectedURL: "http://idp.com/logout",
|
||||
expectedOK: true,
|
||||
},
|
||||
{
|
||||
desc: "should return redirect url for client configured redirect url",
|
||||
cfg: &setting.Cfg{},
|
||||
oauthCfg: &social.OAuthInfo{
|
||||
SignoutRedirectUrl: "http://idp.com/logout",
|
||||
},
|
||||
expectedURL: "http://idp.com/logout",
|
||||
expectedOK: true,
|
||||
},
|
||||
{
|
||||
desc: "client specific url should take precedence",
|
||||
cfg: &setting.Cfg{
|
||||
SignoutRedirectUrl: "http://idp.com/logout",
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{
|
||||
SignoutRedirectUrl: "http://idp-2.com/logout",
|
||||
},
|
||||
expectedURL: "http://idp-2.com/logout",
|
||||
expectedOK: true,
|
||||
},
|
||||
{
|
||||
desc: "should add id token hint if oicd logout is configured and token is valid",
|
||||
cfg: &setting.Cfg{},
|
||||
oauthCfg: &social.OAuthInfo{
|
||||
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
|
||||
},
|
||||
expectedURL: "http://idp.com/logout",
|
||||
expectedIDTokenHint: "id_token_hint=some.id.token",
|
||||
expectedPostLogoutURI: "http%3A%3A%2F%2Ftest.com%2Flogin",
|
||||
expectedOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var (
|
||||
getTokenCalled bool
|
||||
invalidateTokenCalled bool
|
||||
)
|
||||
|
||||
mockService := &oauthtokentest.MockOauthTokenService{
|
||||
GetCurrentOauthTokenFunc: func(_ context.Context, _ identity.Requester) *oauth2.Token {
|
||||
getTokenCalled = true
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "some.access.token",
|
||||
Expiry: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
return token.WithExtra(map[string]any{
|
||||
"id_token": "some.id.token",
|
||||
})
|
||||
},
|
||||
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||
invalidateTokenCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
|
||||
|
||||
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})
|
||||
|
||||
assert.Equal(t, tt.expectedOK, ok)
|
||||
if tt.expectedOK {
|
||||
assert.True(t, strings.HasPrefix(redirect.URL, tt.expectedURL))
|
||||
assert.Contains(t, redirect.URL, tt.expectedIDTokenHint)
|
||||
assert.Contains(t, redirect.URL, tt.expectedPostLogoutURI)
|
||||
}
|
||||
|
||||
assert.True(t, getTokenCalled)
|
||||
assert.True(t, invalidateTokenCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockConnector struct {
|
||||
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
|
||||
social.SocialConnector
|
||||
|
Loading…
Reference in New Issue
Block a user