Session: set authID and authenticatedBy (#85806)

* Authn: Resolve authenticate by and auth id when fethcing signed in user

* Change logout client interface to only take Requester interface

* Session: Fetch external auth info when authenticating sessions

* Use authenticated by from identity

* Move call to get auth-info into session client and use GetAuthenticatedBy in various places
This commit is contained in:
Karl Persson
2024-04-11 10:25:29 +02:00
committed by GitHub
parent f375af793f
commit 895222725c
21 changed files with 230 additions and 185 deletions

View File

@@ -43,7 +43,7 @@ func ProvideRegistration(
authnSvc.RegisterClient(clients.ProvideAPIKey(apikeyService))
if cfg.LoginCookieName != "" {
authnSvc.RegisterClient(clients.ProvideSession(cfg, sessionService))
authnSvc.RegisterClient(clients.ProvideSession(cfg, sessionService, authInfoService))
}
var proxyClients []authn.ProxyClient

View File

@@ -19,7 +19,6 @@ import (
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/clients"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
@@ -47,20 +46,18 @@ func ProvideIdentitySynchronizer(s *Service) authn.IdentitySynchronizer {
func ProvideService(
cfg *setting.Cfg, tracer tracing.Tracer,
sessionService auth.UserTokenService, usageStats usagestats.Service,
authInfoService login.AuthInfoService, registerer prometheus.Registerer,
sessionService auth.UserTokenService, usageStats usagestats.Service, registerer prometheus.Registerer,
) *Service {
s := &Service{
log: log.New("authn.service"),
cfg: cfg,
clients: make(map[string]authn.Client),
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracer,
metrics: newMetrics(registerer),
authInfoService: authInfoService,
sessionService: sessionService,
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
log: log.New("authn.service"),
cfg: cfg,
clients: make(map[string]authn.Client),
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracer,
metrics: newMetrics(registerer),
sessionService: sessionService,
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
}
usageStats.RegisterMetricsFunc(s.getUsageStats)
@@ -77,8 +74,7 @@ type Service struct {
tracer tracing.Tracer
metrics *metrics
authInfoService login.AuthInfoService
sessionService auth.UserTokenService
sessionService auth.UserTokenService
// postAuthHooks are called after a successful authentication. They can modify the identity.
postAuthHooks *queue[authn.PostAuthHookFn]
@@ -259,9 +255,8 @@ func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionTo
return redirect, nil
}
info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID})
if info != nil {
client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_"))
if authModule := user.GetAuthenticatedBy(); authModule != "" {
client := authn.ClientWithPrefix(strings.TrimPrefix(authModule, "oauth_"))
c, ok := s.clients[client]
if !ok {
@@ -275,7 +270,7 @@ func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionTo
goto Default
}
clientRedirect, ok := logoutClient.Logout(ctx, user, info)
clientRedirect, ok := logoutClient.Logout(ctx, user)
if !ok {
goto Default
}

View File

@@ -19,8 +19,6 @@ import (
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
@@ -309,7 +307,6 @@ func TestService_Logout(t *testing.T) {
identity *authn.Identity
sessionToken *usertoken.UserToken
info *login.UserAuth
client authn.Client
@@ -332,27 +329,24 @@ func TestService_Logout(t *testing.T) {
},
{
desc: "should redirect to default redirect url when client is not found",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "notFound"},
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "notfound"},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
expectedTokenRevoked: true,
},
{
desc: "should redirect to default redirect url when client do not implement logout extension",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "azuread"},
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "azuread"},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"},
expectedTokenRevoked: true,
},
{
desc: "should redirect to client specific url",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "azuread"},
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "azuread"},
expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"},
client: &authntest.MockClient{
NameFunc: func() string { return "auth.client.azuread" },
LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) {
LogoutFunc: func(ctx context.Context, _ identity.Requester) (*authn.Redirect, bool) {
return &authn.Redirect{URL: "http://idp.com/logout"}, true
},
},
@@ -369,9 +363,6 @@ func TestService_Logout(t *testing.T) {
svc.RegisterClient(tt.client)
}
svc.cfg.AppSubURL = "http://localhost:3000"
svc.authInfoService = &authinfotest.FakeService{
ExpectedUserAuth: tt.info,
}
svc.sessionService = &authtest.FakeUserAuthTokenService{
RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error {

View File

@@ -3,6 +3,7 @@ package sync
import (
"context"
"errors"
"strings"
"time"
"golang.org/x/sync/singleflight"
@@ -39,11 +40,16 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
// not authenticated through session tokens, so we can skip this hook
// Not authenticated through session tokens, so we can skip this hook.
if identity.SessionToken == nil {
return nil
}
// Not authenticated with a oauth provider, so we can skip this hook.
if !strings.HasPrefix(identity.GetAuthenticatedBy(), "oauth") {
return nil
}
_, err, _ := s.singleflightGroup.Do(identity.ID, func() (interface{}, error) {
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)

View File

@@ -51,7 +51,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
},
{
desc: "should invalidate access token and session token if token refresh fails",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: true,
expectedTryRefreshErr: errors.New("some err"),
expectTryRefreshTokenCalled: true,
@@ -62,7 +62,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
},
{
desc: "should refresh the token successfully",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: false,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,
@@ -70,7 +70,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
},
{
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/grafana/grafana/pkg/infra/log"
authidentity "github.com/grafana/grafana/pkg/services/auth/identity"
@@ -110,12 +111,13 @@ func (s *UserSync) FetchSyncedUserHook(ctx context.Context, identity *authn.Iden
if !identity.ClientParams.FetchSyncedUser {
return nil
}
namespace, id := identity.GetNamespacedID()
if namespace != authn.NamespaceUser && namespace != authn.NamespaceServiceAccount {
if !authidentity.IsNamespace(namespace, authn.NamespaceUser, authn.NamespaceServiceAccount) {
return nil
}
userID, err := authidentity.IntIdentifier(namespace, id)
userID, err := strconv.ParseInt(id, 10, 64)
if err != nil {
s.log.FromContext(ctx).Warn("got invalid identity ID", "id", id, "err", err)
return nil