mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Authn: Handle logout logic in auth broker (#79635)
* AuthN: Add new client extension interface that allows for custom logout logic * AuthN: Add tests for oauth client logout * Call authn.Logout Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
This commit is contained in:
parent
eb490193b9
commit
8cb351e54a
119
pkg/api/login.go
119
pkg/api/login.go
@ -29,8 +29,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
viewIndex = "index"
|
viewIndex = "index"
|
||||||
loginErrorCookieName = "login_error"
|
loginErrorCookieName = "login_error"
|
||||||
// #nosec G101 - this is not a hardcoded secret
|
|
||||||
postLogoutRedirectParam = "post_logout_redirect_uri"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var setIndexViewData = (*HTTPServer).setIndexViewData
|
var setIndexViewData = (*HTTPServer).setIndexViewData
|
||||||
@ -243,70 +241,31 @@ func (hs *HTTPServer) loginUserWithUser(user *user.User, c *contextmodel.ReqCont
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
|
func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
|
||||||
userID, errID := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
|
// FIXME: restructure saml client to implement authn.LogoutClient
|
||||||
if errID != nil {
|
if hs.samlSingleLogoutEnabled() {
|
||||||
hs.log.Error("failed to retrieve user ID", "error", errID)
|
id, err := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
|
||||||
}
|
if err != nil {
|
||||||
|
hs.log.Error("failed to retrieve user ID", "error", err)
|
||||||
oauthProviderSignoutRedirectUrl := ""
|
|
||||||
getAuthQuery := loginservice.GetAuthInfoQuery{UserId: userID}
|
|
||||||
authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery)
|
|
||||||
if err == nil {
|
|
||||||
// If SAML is enabled and this is a SAML user use saml logout
|
|
||||||
if hs.samlSingleLogoutEnabled() {
|
|
||||||
if authInfo.AuthModule == loginservice.SAMLAuthModule {
|
|
||||||
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
oauthProvider := hs.SocialService.GetOAuthInfoProvider(strings.TrimPrefix(authInfo.AuthModule, "oauth_"))
|
|
||||||
if oauthProvider != nil {
|
authInfo, _ := hs.authInfoService.GetAuthInfo(c.Req.Context(), &loginservice.GetAuthInfoQuery{UserId: id})
|
||||||
oauthProviderSignoutRedirectUrl = oauthProvider.SignoutRedirectUrl
|
if authInfo != nil && authInfo.AuthModule == loginservice.SAMLAuthModule {
|
||||||
|
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.log.Debug("Logout Redirect url", "auth.SignoutRedirectUrl:", hs.Cfg.SignoutRedirectUrl)
|
redirect, err := hs.authnService.Logout(c.Req.Context(), c.SignedInUser, c.UserToken)
|
||||||
hs.log.Debug("Logout Redirect url", "oauth provider redirect url:", oauthProviderSignoutRedirectUrl)
|
|
||||||
|
|
||||||
signOutRedirectUrl := getSignOutRedirectUrl(hs.Cfg.SignoutRedirectUrl, oauthProviderSignoutRedirectUrl)
|
|
||||||
|
|
||||||
hs.log.Debug("Logout Redirect url", "signOurRedirectUrl:", signOutRedirectUrl)
|
|
||||||
idTokenHint := ""
|
|
||||||
oidcLogout := isPostLogoutRedirectConfigured(signOutRedirectUrl)
|
|
||||||
|
|
||||||
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
|
|
||||||
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
|
|
||||||
token := hs.oauthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser)
|
|
||||||
if oidcLogout {
|
|
||||||
if token.Valid() {
|
|
||||||
idTokenHint = token.Extra("id_token").(string)
|
|
||||||
} else {
|
|
||||||
hs.log.Warn("Token is not valid")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
|
|
||||||
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", userID, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
|
|
||||||
if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) {
|
|
||||||
hs.log.Error("failed to revoke auth token", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
authn.DeleteSessionCookie(c.Resp, hs.Cfg)
|
authn.DeleteSessionCookie(c.Resp, hs.Cfg)
|
||||||
|
|
||||||
rdUrl := signOutRedirectUrl
|
if err != nil {
|
||||||
if rdUrl != "" {
|
hs.log.Error("Failed perform proper logout", "error", err)
|
||||||
if oidcLogout {
|
|
||||||
rdUrl = getPostRedirectUrl(signOutRedirectUrl, idTokenHint)
|
|
||||||
}
|
|
||||||
c.Redirect(rdUrl)
|
|
||||||
} else {
|
|
||||||
hs.log.Info("Successful Logout", "User", c.SignedInUser.GetEmail())
|
|
||||||
c.Redirect(hs.Cfg.AppSubURL + "/login")
|
c.Redirect(hs.Cfg.AppSubURL + "/login")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, id := c.SignedInUser.GetNamespacedID()
|
||||||
|
hs.log.Info("Successful Logout", "userID", id)
|
||||||
|
c.Redirect(redirect.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *HTTPServer) tryGetEncryptedCookie(ctx *contextmodel.ReqContext, cookieName string) (string, bool) {
|
func (hs *HTTPServer) tryGetEncryptedCookie(ctx *contextmodel.ReqContext, cookieName string) (string, bool) {
|
||||||
@ -420,47 +379,3 @@ func getFirstPublicErrorMessage(err *errutil.Error) string {
|
|||||||
|
|
||||||
return errPublic.Message
|
return errPublic.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPostLogoutRedirectConfigured(redirectUrl string) bool {
|
|
||||||
if redirectUrl == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(redirectUrl)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
q := u.Query()
|
|
||||||
_, ok := q[postLogoutRedirectParam]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPostRedirectUrl(rdUrl string, tokenHint string) string {
|
|
||||||
if tokenHint == "" {
|
|
||||||
return rdUrl
|
|
||||||
}
|
|
||||||
if rdUrl == "" {
|
|
||||||
return rdUrl
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(rdUrl)
|
|
||||||
if err != nil {
|
|
||||||
return rdUrl
|
|
||||||
}
|
|
||||||
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("id_token_hint", tokenHint)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSignOutRedirectUrl(gRdUrl string, oauthProviderUrl string) string {
|
|
||||||
if oauthProviderUrl != "" {
|
|
||||||
return oauthProviderUrl
|
|
||||||
} else if gRdUrl != "" {
|
|
||||||
return gRdUrl
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/api/response"
|
"github.com/grafana/grafana/pkg/api/response"
|
||||||
"github.com/grafana/grafana/pkg/middleware/cookies"
|
"github.com/grafana/grafana/pkg/middleware/cookies"
|
||||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
@ -74,6 +75,8 @@ type Service interface {
|
|||||||
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
|
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
|
||||||
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
|
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
|
||||||
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
|
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
|
||||||
|
// Logout revokes session token and does additional clean up if client used to authenticate supports it
|
||||||
|
Logout(ctx context.Context, user identity.Requester, sessionToken *usertoken.UserToken) (*Redirect, error)
|
||||||
// RegisterClient will register a new authn.Client that can be used for authentication
|
// RegisterClient will register a new authn.Client that can be used for authentication
|
||||||
RegisterClient(c Client)
|
RegisterClient(c Client)
|
||||||
}
|
}
|
||||||
@ -115,6 +118,14 @@ type RedirectClient interface {
|
|||||||
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
|
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogoutCLient is an optional interface that auth client can implement.
|
||||||
|
// Clients that implements this interface can implement additional logic
|
||||||
|
// that should happen during logout and supports client specific redirect URL.
|
||||||
|
type LogoutClient interface {
|
||||||
|
Client
|
||||||
|
Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*Redirect, bool)
|
||||||
|
}
|
||||||
|
|
||||||
type PasswordClient interface {
|
type PasswordClient interface {
|
||||||
AuthenticatePassword(ctx context.Context, r *Request, username, password string) (*Identity, error)
|
AuthenticatePassword(ctx context.Context, r *Request, username, password string) (*Identity, error)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
@ -19,6 +20,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||||
"github.com/grafana/grafana/pkg/services/apikey"
|
"github.com/grafana/grafana/pkg/services/apikey"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
|
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
|
||||||
"github.com/grafana/grafana/pkg/services/authn/clients"
|
"github.com/grafana/grafana/pkg/services/authn/clients"
|
||||||
@ -73,15 +75,16 @@ func ProvideService(
|
|||||||
signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server,
|
signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server,
|
||||||
) *Service {
|
) *Service {
|
||||||
s := &Service{
|
s := &Service{
|
||||||
log: log.New("authn.service"),
|
log: log.New("authn.service"),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
clients: make(map[string]authn.Client),
|
clients: make(map[string]authn.Client),
|
||||||
clientQueue: newQueue[authn.ContextAwareClient](),
|
clientQueue: newQueue[authn.ContextAwareClient](),
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
metrics: newMetrics(registerer),
|
metrics: newMetrics(registerer),
|
||||||
sessionService: sessionService,
|
authInfoService: authInfoService,
|
||||||
postAuthHooks: newQueue[authn.PostAuthHookFn](),
|
sessionService: sessionService,
|
||||||
postLoginHooks: newQueue[authn.PostLoginHookFn](),
|
postAuthHooks: newQueue[authn.PostAuthHookFn](),
|
||||||
|
postLoginHooks: newQueue[authn.PostLoginHookFn](),
|
||||||
}
|
}
|
||||||
|
|
||||||
usageStats.RegisterMetricsFunc(s.getUsageStats)
|
usageStats.RegisterMetricsFunc(s.getUsageStats)
|
||||||
@ -146,7 +149,7 @@ func ProvideService(
|
|||||||
if errConnector != nil || errHTTPClient != nil {
|
if errConnector != nil || errHTTPClient != nil {
|
||||||
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
|
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
|
||||||
} else {
|
} else {
|
||||||
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient))
|
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,7 +178,8 @@ type Service struct {
|
|||||||
tracer tracing.Tracer
|
tracer tracing.Tracer
|
||||||
metrics *metrics
|
metrics *metrics
|
||||||
|
|
||||||
sessionService auth.UserTokenService
|
authInfoService login.AuthInfoService
|
||||||
|
sessionService auth.UserTokenService
|
||||||
|
|
||||||
// postAuthHooks are called after a successful authentication. They can modify the identity.
|
// postAuthHooks are called after a successful authentication. They can modify the identity.
|
||||||
postAuthHooks *queue[authn.PostAuthHookFn]
|
postAuthHooks *queue[authn.PostAuthHookFn]
|
||||||
@ -335,6 +339,55 @@ func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Reque
|
|||||||
return redirectClient.RedirectURL(ctx, r)
|
return redirectClient.RedirectURL(ctx, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionToken *auth.UserToken) (*authn.Redirect, error) {
|
||||||
|
ctx, span := s.tracer.Start(ctx, "authn.Logout")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
redirect := &authn.Redirect{URL: s.cfg.AppSubURL + "/login"}
|
||||||
|
|
||||||
|
namespace, id := user.GetNamespacedID()
|
||||||
|
if namespace != authn.NamespaceUser {
|
||||||
|
return redirect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := identity.IntIdentifier(namespace, id)
|
||||||
|
if err != nil {
|
||||||
|
s.log.FromContext(ctx).Debug("Invalid user id", "id", userID, "err", err)
|
||||||
|
return redirect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID})
|
||||||
|
if info != nil {
|
||||||
|
client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_"))
|
||||||
|
|
||||||
|
c, ok := s.clients[client]
|
||||||
|
if !ok {
|
||||||
|
s.log.FromContext(ctx).Debug("No client configured for auth module", "client", client)
|
||||||
|
goto Default
|
||||||
|
}
|
||||||
|
|
||||||
|
logoutClient, ok := c.(authn.LogoutClient)
|
||||||
|
if !ok {
|
||||||
|
s.log.FromContext(ctx).Debug("Client do not support specialized logout logic", "client", client)
|
||||||
|
goto Default
|
||||||
|
}
|
||||||
|
|
||||||
|
clientRedirect, ok := logoutClient.Logout(ctx, user, info)
|
||||||
|
if !ok {
|
||||||
|
goto Default
|
||||||
|
}
|
||||||
|
|
||||||
|
redirect = clientRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
Default:
|
||||||
|
if err = s.sessionService.RevokeToken(ctx, sessionToken, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return redirect, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) RegisterClient(c authn.Client) {
|
func (s *Service) RegisterClient(c authn.Client) {
|
||||||
s.clients[c.Name()] = c
|
s.clients[c.Name()] = c
|
||||||
if cac, ok := c.(authn.ContextAwareClient); ok {
|
if cac, ok := c.(authn.ContextAwareClient); ok {
|
||||||
|
@ -13,10 +13,14 @@ import (
|
|||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||||
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/authinfotest"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
@ -299,6 +303,95 @@ func TestService_RedirectURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestService_Logout(t *testing.T) {
|
||||||
|
type TestCase struct {
|
||||||
|
desc string
|
||||||
|
|
||||||
|
identity *authn.Identity
|
||||||
|
sessionToken *usertoken.UserToken
|
||||||
|
info *login.UserAuth
|
||||||
|
|
||||||
|
client authn.Client
|
||||||
|
|
||||||
|
expectedErr error
|
||||||
|
expectedTokenRevoked bool
|
||||||
|
expectedRedirect *authn.Redirect
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []TestCase{
|
||||||
|
{
|
||||||
|
desc: "should redirect to default redirect url when identity is not a user",
|
||||||
|
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceServiceAccount, 1)},
|
||||||
|
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should redirect to default redirect url when no external provider was used to authenticate",
|
||||||
|
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||||
|
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||||
|
expectedTokenRevoked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should redirect to default redirect url when client is not found",
|
||||||
|
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||||
|
info: &login.UserAuth{AuthModule: "notFound"},
|
||||||
|
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||||
|
expectedTokenRevoked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should redirect to default redirect url when client do not implement logout extension",
|
||||||
|
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||||
|
info: &login.UserAuth{AuthModule: "azuread"},
|
||||||
|
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
|
||||||
|
client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"},
|
||||||
|
expectedTokenRevoked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should redirect to client specific url",
|
||||||
|
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
|
||||||
|
info: &login.UserAuth{AuthModule: "azuread"},
|
||||||
|
expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"},
|
||||||
|
client: &authntest.MockClient{
|
||||||
|
NameFunc: func() string { return "auth.client.azuread" },
|
||||||
|
LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) {
|
||||||
|
return &authn.Redirect{URL: "http://idp.com/logout"}, true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedTokenRevoked: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
var tokenRevoked bool
|
||||||
|
|
||||||
|
s := setupTests(t, func(svc *Service) {
|
||||||
|
if tt.client != nil {
|
||||||
|
svc.RegisterClient(tt.client)
|
||||||
|
}
|
||||||
|
svc.cfg.AppSubURL = "http://localhost:3000"
|
||||||
|
svc.authInfoService = &authinfotest.FakeService{
|
||||||
|
ExpectedUserAuth: tt.info,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.sessionService = &authtest.FakeUserAuthTokenService{
|
||||||
|
RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error {
|
||||||
|
tokenRevoked = true
|
||||||
|
assert.EqualValues(t, tt.sessionToken, sessionToken)
|
||||||
|
assert.False(t, soft)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
redirect, err := s.Logout(context.Background(), tt.identity, tt.sessionToken)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
assert.EqualValues(t, tt.expectedRedirect, redirect)
|
||||||
|
assert.Equal(t, tt.expectedTokenRevoked, tokenRevoked)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func mustParseURL(s string) *url.URL {
|
func mustParseURL(s string) *url.URL {
|
||||||
u, err := url.Parse(s)
|
u, err := url.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -3,6 +3,8 @@ package authntest
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,6 +68,10 @@ func (f *FakeService) RedirectURL(ctx context.Context, client string, r *authn.R
|
|||||||
return f.ExpectedRedirect, f.ExpectedErr
|
return f.ExpectedRedirect, f.ExpectedErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*FakeService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
|
||||||
|
panic("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (f *FakeService) RegisterClient(c authn.Client) {}
|
func (f *FakeService) RegisterClient(c authn.Client) {}
|
||||||
|
|
||||||
func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
||||||
|
@ -3,7 +3,10 @@ package authntest
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ authn.Service = new(MockService)
|
var _ authn.Service = new(MockService)
|
||||||
@ -40,6 +43,10 @@ func (m *MockService) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority
|
|||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*MockService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
|
||||||
|
panic("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
|
||||||
if m.SyncIdentityFunc != nil {
|
if m.SyncIdentityFunc != nil {
|
||||||
return m.SyncIdentityFunc(ctx, identity)
|
return m.SyncIdentityFunc(ctx, identity)
|
||||||
@ -48,6 +55,7 @@ func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity
|
|||||||
}
|
}
|
||||||
|
|
||||||
var _ authn.HookClient = new(MockClient)
|
var _ authn.HookClient = new(MockClient)
|
||||||
|
var _ authn.LogoutClient = new(MockClient)
|
||||||
var _ authn.ContextAwareClient = new(MockClient)
|
var _ authn.ContextAwareClient = new(MockClient)
|
||||||
|
|
||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
@ -56,6 +64,7 @@ type MockClient struct {
|
|||||||
TestFunc func(ctx context.Context, r *authn.Request) bool
|
TestFunc func(ctx context.Context, r *authn.Request) bool
|
||||||
PriorityFunc func() uint
|
PriorityFunc func() uint
|
||||||
HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error
|
HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error
|
||||||
|
LogoutFunc func(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m MockClient) Name() string {
|
func (m MockClient) Name() string {
|
||||||
@ -93,6 +102,13 @@ func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
|
||||||
|
if m.LogoutFunc != nil {
|
||||||
|
return m.LogoutFunc(ctx, user, info)
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
var _ authn.ProxyClient = new(MockProxyClient)
|
var _ authn.ProxyClient = new(MockProxyClient)
|
||||||
|
|
||||||
type MockProxyClient struct {
|
type MockProxyClient struct {
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@ -16,8 +17,10 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/login/social/connectors"
|
"github.com/grafana/grafana/pkg/login/social/connectors"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/util/errutil"
|
"github.com/grafana/grafana/pkg/util/errutil"
|
||||||
@ -30,9 +33,10 @@ const (
|
|||||||
codeChallengeMethodParamName = "code_challenge_method"
|
codeChallengeMethodParamName = "code_challenge_method"
|
||||||
codeChallengeMethod = "S256"
|
codeChallengeMethod = "S256"
|
||||||
|
|
||||||
oauthStateQueryName = "state"
|
oauthStateQueryName = "state"
|
||||||
oauthStateCookieName = "oauth_state"
|
oauthStateCookieName = "oauth_state"
|
||||||
oauthPKCECookieName = "oauth_code_verifier"
|
oauthPKCECookieName = "oauth_code_verifier"
|
||||||
|
oauthPostLogoutRedirectParam = "post_logout_redirect_uri"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -54,26 +58,28 @@ func fromSocialErr(err *connectors.SocialError) error {
|
|||||||
return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err)
|
return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ authn.LogoutClient = new(OAuth)
|
||||||
var _ authn.RedirectClient = new(OAuth)
|
var _ authn.RedirectClient = new(OAuth)
|
||||||
|
|
||||||
func ProvideOAuth(
|
func ProvideOAuth(
|
||||||
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
||||||
connector social.SocialConnector, httpClient *http.Client,
|
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
|
||||||
) *OAuth {
|
) *OAuth {
|
||||||
return &OAuth{
|
return &OAuth{
|
||||||
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
||||||
log.New(name), cfg, oauthCfg, connector, httpClient,
|
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuth struct {
|
type OAuth struct {
|
||||||
name string
|
name string
|
||||||
moduleName string
|
moduleName string
|
||||||
log log.Logger
|
log log.Logger
|
||||||
cfg *setting.Cfg
|
cfg *setting.Cfg
|
||||||
oauthCfg *social.OAuthInfo
|
oauthCfg *social.OAuthInfo
|
||||||
connector social.SocialConnector
|
connector social.SocialConnector
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
oauthService oauthtoken.OAuthTokenService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OAuth) Name() string {
|
func (c *OAuth) Name() string {
|
||||||
@ -204,6 +210,29 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
|
||||||
|
token := c.oauthService.GetCurrentOAuthToken(ctx, user)
|
||||||
|
|
||||||
|
if err := c.oauthService.InvalidateOAuthTokens(ctx, info); err != nil {
|
||||||
|
namespace, id := user.GetNamespacedID()
|
||||||
|
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
|
||||||
|
if redirctURL == "" {
|
||||||
|
c.log.FromContext(ctx).Debug("No signout redirect url configured")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
|
||||||
|
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||||
|
redirctURL = withIDTokenHint(redirctURL, idToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &authn.Redirect{URL: redirctURL}, true
|
||||||
|
}
|
||||||
|
|
||||||
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
|
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
|
||||||
func genPKCECode() (string, string, error) {
|
func genPKCECode() (string, string, error) {
|
||||||
// IETF RFC 7636 specifies that the code verifier should be 43-128
|
// IETF RFC 7636 specifies that the code verifier should be 43-128
|
||||||
@ -243,3 +272,43 @@ func hashOAuthState(state, secret, seed string) string {
|
|||||||
hashBytes := sha256.Sum256([]byte(state + secret + seed))
|
hashBytes := sha256.Sum256([]byte(state + secret + seed))
|
||||||
return hex.EncodeToString(hashBytes[:])
|
return hex.EncodeToString(hashBytes[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getOAuthSignoutRedirectURL(cfg *setting.Cfg, oauthCfg *social.OAuthInfo) string {
|
||||||
|
if oauthCfg.SignoutRedirectUrl != "" {
|
||||||
|
return oauthCfg.SignoutRedirectUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg.SignoutRedirectUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
func withIDTokenHint(redirectURL string, idToken string) string {
|
||||||
|
if idToken == "" {
|
||||||
|
return redirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(redirectURL)
|
||||||
|
if err != nil {
|
||||||
|
return redirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("id_token_hint", idToken)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOICDLogout(redirectUrl string) bool {
|
||||||
|
if redirectUrl == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(redirectUrl)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
_, ok := q[oauthPostLogoutRedirectParam]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
@ -12,8 +14,10 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
@ -212,7 +216,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
ExpectedToken: &oauth2.Token{},
|
ExpectedToken: &oauth2.Token{},
|
||||||
ExpectedIsSignupAllowed: true,
|
ExpectedIsSignupAllowed: true,
|
||||||
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
||||||
}, nil)
|
}, nil, nil)
|
||||||
identity, err := c.Authenticate(context.Background(), tt.req)
|
identity, err := c.Authenticate(context.Background(), tt.req)
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
|
||||||
@ -281,7 +285,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
|||||||
require.Len(t, opts, tt.numCallOptions)
|
require.Len(t, opts, tt.numCallOptions)
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil, nil)
|
||||||
|
|
||||||
redirect, err := c.RedirectURL(context.Background(), nil)
|
redirect, err := c.RedirectURL(context.Background(), nil)
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
@ -299,6 +303,107 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOAuth_Logout(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
cfg *setting.Cfg
|
||||||
|
oauthCfg *social.OAuthInfo
|
||||||
|
|
||||||
|
expectedOK bool
|
||||||
|
expectedURL string
|
||||||
|
expectedIDTokenHint string
|
||||||
|
expectedPostLogoutURI string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should not return redirect url if not configured for client or globably",
|
||||||
|
cfg: &setting.Cfg{},
|
||||||
|
oauthCfg: &social.OAuthInfo{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should return redirect url for globably configured redirect url",
|
||||||
|
cfg: &setting.Cfg{
|
||||||
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
|
},
|
||||||
|
oauthCfg: &social.OAuthInfo{},
|
||||||
|
expectedURL: "http://idp.com/logout",
|
||||||
|
expectedOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should return redirect url for client configured redirect url",
|
||||||
|
cfg: &setting.Cfg{},
|
||||||
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
|
},
|
||||||
|
expectedURL: "http://idp.com/logout",
|
||||||
|
expectedOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "client specific url should take precedence",
|
||||||
|
cfg: &setting.Cfg{
|
||||||
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
|
},
|
||||||
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
SignoutRedirectUrl: "http://idp-2.com/logout",
|
||||||
|
},
|
||||||
|
expectedURL: "http://idp-2.com/logout",
|
||||||
|
expectedOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should add id token hint if oicd logout is configured and token is valid",
|
||||||
|
cfg: &setting.Cfg{},
|
||||||
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
|
||||||
|
},
|
||||||
|
expectedURL: "http://idp.com/logout",
|
||||||
|
expectedIDTokenHint: "id_token_hint=some.id.token",
|
||||||
|
expectedPostLogoutURI: "http%3A%3A%2F%2Ftest.com%2Flogin",
|
||||||
|
expectedOK: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
getTokenCalled bool
|
||||||
|
invalidateTokenCalled bool
|
||||||
|
)
|
||||||
|
|
||||||
|
mockService := &oauthtokentest.MockOauthTokenService{
|
||||||
|
GetCurrentOauthTokenFunc: func(_ context.Context, _ identity.Requester) *oauth2.Token {
|
||||||
|
getTokenCalled = true
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "some.access.token",
|
||||||
|
Expiry: time.Now().Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
return token.WithExtra(map[string]any{
|
||||||
|
"id_token": "some.id.token",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
|
||||||
|
invalidateTokenCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
|
||||||
|
|
||||||
|
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedOK, ok)
|
||||||
|
if tt.expectedOK {
|
||||||
|
assert.True(t, strings.HasPrefix(redirect.URL, tt.expectedURL))
|
||||||
|
assert.Contains(t, redirect.URL, tt.expectedIDTokenHint)
|
||||||
|
assert.Contains(t, redirect.URL, tt.expectedPostLogoutURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, getTokenCalled)
|
||||||
|
assert.True(t, invalidateTokenCalled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type mockConnector struct {
|
type mockConnector struct {
|
||||||
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
|
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
|
||||||
social.SocialConnector
|
social.SocialConnector
|
||||||
|
Loading…
Reference in New Issue
Block a user