From 895222725c17caf93b642b60a8bea996333c43af Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Thu, 11 Apr 2024 10:25:29 +0200 Subject: [PATCH] Session: set authID and authenticatedBy (#85806) * Authn: Resolve authenticate by and auth id when fethcing signed in user * Change logout client interface to only take Requester interface * Session: Fetch external auth info when authenticating sessions * Use authenticated by from identity * Move call to get auth-info into session client and use GetAuthenticatedBy in various places --- pkg/api/index.go | 47 ++------------- pkg/api/login.go | 8 +-- pkg/api/login_test.go | 3 +- pkg/services/auth/identity/requester.go | 7 ++- pkg/services/auth/idimpl/service.go | 54 ++++------------- pkg/services/auth/idimpl/service_test.go | 12 ++-- pkg/services/authn/authn.go | 2 +- pkg/services/authn/authnimpl/registration.go | 2 +- pkg/services/authn/authnimpl/service.go | 33 +++++------ pkg/services/authn/authnimpl/service_test.go | 17 ++---- .../authn/authnimpl/sync/oauth_token_sync.go | 8 ++- .../authnimpl/sync/oauth_token_sync_test.go | 6 +- .../authn/authnimpl/sync/user_sync.go | 6 +- pkg/services/authn/authntest/mock.go | 7 +-- pkg/services/authn/clients/oauth.go | 15 ++++- pkg/services/authn/clients/oauth_test.go | 2 +- pkg/services/authn/clients/session.go | 35 ++++++++--- pkg/services/authn/clients/session_test.go | 58 +++++++++++++++++-- pkg/services/authn/identity.go | 5 ++ pkg/services/user/identity.go | 34 ++++++----- pkg/tests/web/index_view_test.go | 54 ++++++++++++++--- 21 files changed, 230 insertions(+), 185 deletions(-) diff --git a/pkg/api/index.go b/pkg/api/index.go index 3eb92001f94..7868ddce08f 100644 --- a/pkg/api/index.go +++ b/pkg/api/index.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" - "errors" "fmt" "net/http" "strings" @@ -20,7 +19,6 @@ import ( "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/org" pref "github.com/grafana/grafana/pkg/services/preference" - "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -112,7 +110,7 @@ func (hs *HTTPServer) setIndexViewData(c *contextmodel.ReqContext) (*dtos.IndexV HelpFlags1: c.HelpFlags1, HasEditPermissionInFolders: hasEditPerm, Analytics: hs.buildUserAnalyticsSettings(c), - AuthenticatedBy: hs.getUserAuthenticatedBy(c, userID), + AuthenticatedBy: c.GetAuthenticatedBy(), }, Settings: settings, ThemeType: theme.Type, @@ -168,7 +166,7 @@ func (hs *HTTPServer) setIndexViewData(c *contextmodel.ReqContext) (*dtos.IndexV } func (hs *HTTPServer) buildUserAnalyticsSettings(c *contextmodel.ReqContext) dtos.AnalyticsSettings { - namespace, id := c.SignedInUser.GetNamespacedID() + namespace, _ := c.SignedInUser.GetNamespacedID() // Anonymous users do not have an email or auth info if namespace != identity.NamespaceUser { @@ -179,21 +177,10 @@ func (hs *HTTPServer) buildUserAnalyticsSettings(c *contextmodel.ReqContext) dto return dtos.AnalyticsSettings{} } - userID, err := identity.IntIdentifier(namespace, id) - if err != nil { - hs.log.Error("Failed to parse user ID", "error", err) - return dtos.AnalyticsSettings{Identifier: "@" + hs.Cfg.AppURL} - } - identifier := c.SignedInUser.GetEmail() + "@" + hs.Cfg.AppURL - authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &login.GetAuthInfoQuery{UserId: userID}) - if err != nil && !errors.Is(err, user.ErrUserNotFound) { - hs.log.Error("Failed to get auth info for analytics", "error", err) - } - - if authInfo != nil && authInfo.AuthModule == login.GrafanaComAuthModule { - identifier = authInfo.AuthId + if authenticatedBy := c.SignedInUser.GetAuthenticatedBy(); authenticatedBy == login.GrafanaComAuthModule { + identifier = c.SignedInUser.GetAuthID() } return dtos.AnalyticsSettings{ @@ -216,32 +203,6 @@ func (hs *HTTPServer) getUserOrgCount(c *contextmodel.ReqContext, userID int64) return len(userOrgs) } -// getUserAuthenticatedBy returns external authentication method used for user. -// If user does not have an external authentication method an empty string is returned -func (hs *HTTPServer) getUserAuthenticatedBy(c *contextmodel.ReqContext, userID int64) string { - if userID == 0 { - return "" - } - - // Special case for image renderer. Frontend relies on this information - // to render dashboards in a bit different way. - if c.IsRenderCall { - return login.RenderModule - } - - info, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &login.GetAuthInfoQuery{UserId: userID}) - // we ignore errors where a user does not have external user auth - if err != nil && !errors.Is(err, user.ErrUserNotFound) { - hs.log.FromContext(c.Req.Context()).Error("Failed to fetch auth info", "userId", c.SignedInUser.UserID, "error", err) - } - - if err != nil { - return "" - } - - return info.AuthModule -} - func hashUserIdentifier(identifier string, secret string) string { if secret == "" { return "" diff --git a/pkg/api/login.go b/pkg/api/login.go index 8e5d570044d..c1638600678 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -243,13 +243,7 @@ func (hs *HTTPServer) loginUserWithUser(user *user.User, c *contextmodel.ReqCont func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) { // FIXME: restructure saml client to implement authn.LogoutClient if hs.samlSingleLogoutEnabled() { - id, err := identity.UserIdentifier(c.SignedInUser.GetNamespacedID()) - if err != nil { - hs.log.Error("failed to retrieve user ID", "error", err) - } - - authInfo, _ := hs.authInfoService.GetAuthInfo(c.Req.Context(), &loginservice.GetAuthInfoQuery{UserId: id}) - if authInfo != nil && authInfo.AuthModule == loginservice.SAMLAuthModule { + if c.SignedInUser.GetAuthenticatedBy() == loginservice.SAMLAuthModule { c.Redirect(hs.Cfg.AppSubURL + "/logout/saml") return } diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index 2b0ef5df7c3..4c8dcd77742 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -670,7 +670,8 @@ func TestLogoutSaml(t *testing.T) { assert.Equal(t, true, hs.samlSingleLogoutEnabled()) sc.defaultHandler = routing.Wrap(func(c *contextmodel.ReqContext) response.Response { c.SignedInUser = &user.SignedInUser{ - UserID: 1, + UserID: 1, + AuthenticatedBy: loginservice.SAMLAuthModule, } hs.Logout(c) return response.Empty(http.StatusOK) diff --git a/pkg/services/auth/identity/requester.go b/pkg/services/auth/identity/requester.go index 9bdd2a1cfd2..57c82e9ecd9 100644 --- a/pkg/services/auth/identity/requester.go +++ b/pkg/services/auth/identity/requester.go @@ -53,9 +53,12 @@ type Requester interface { // DEPRECATED: GetOrgName returns the name of the active organization. // Retrieve the organization name from the organization service instead of using this method. GetOrgName() string + // GetAuthID returns external id for entity. + GetAuthID() string + // GetAuthenticatedBy returns the authentication method used to authenticate the entity. + GetAuthenticatedBy() string // IsAuthenticatedBy returns true if entity was authenticated by any of supplied providers. IsAuthenticatedBy(providers ...string) bool - // IsNil returns true if the identity is nil // FIXME: remove this method once all services are using an interface IsNil() bool @@ -69,8 +72,6 @@ type Requester interface { GetCacheKey() string // HasUniqueId returns true if the entity has a unique id HasUniqueId() bool - // AuthenticatedBy returns the authentication method used to authenticate the entity. - GetAuthenticatedBy() string // GetIDToken returns a signed token representing the identity that can be forwarded to plugins and external services. // Will only be set when featuremgmt.FlagIdForwarding is enabled. GetIDToken() string diff --git a/pkg/services/auth/idimpl/service.go b/pkg/services/auth/idimpl/service.go index 13f897260fe..67cc9bc1e72 100644 --- a/pkg/services/auth/idimpl/service.go +++ b/pkg/services/auth/idimpl/service.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "time" "github.com/go-jose/go-jose/v3/jwt" @@ -17,8 +16,6 @@ import ( "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -33,12 +30,12 @@ var _ auth.IDService = (*Service)(nil) func ProvideService( cfg *setting.Cfg, signer auth.IDSigner, cache remotecache.CacheStorage, features featuremgmt.FeatureToggles, authnService authn.Service, - authInfoService login.AuthInfoService, reg prometheus.Registerer, + reg prometheus.Registerer, ) *Service { s := &Service{ cfg: cfg, logger: log.New("id-service"), signer: signer, cache: cache, - authInfoService: authInfoService, metrics: newMetrics(reg), + metrics: newMetrics(reg), } if features.IsEnabledGlobally(featuremgmt.FlagIdForwarding) { @@ -49,13 +46,12 @@ func ProvideService( } type Service struct { - cfg *setting.Cfg - logger log.Logger - signer auth.IDSigner - cache remotecache.CacheStorage - authInfoService login.AuthInfoService - si singleflight.Group - metrics *metrics + cfg *setting.Cfg + logger log.Logger + signer auth.IDSigner + cache remotecache.CacheStorage + si singleflight.Group + metrics *metrics } func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (string, error) { @@ -90,9 +86,9 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri } if identity.IsNamespace(namespace, identity.NamespaceUser) { - if err := s.setUserClaims(ctx, id, identifier, claims); err != nil { - return "", err - } + claims.Email = id.GetEmail() + claims.EmailVerified = id.IsEmailVerified() + claims.AuthenticatedBy = id.GetAuthenticatedBy() } token, err := s.signer.SignIDToken(ctx, claims) @@ -134,34 +130,6 @@ func (s *Service) RemoveIDToken(ctx context.Context, id identity.Requester) erro return s.cache.Delete(ctx, prefixCacheKey(id.GetCacheKey())) } -func (s *Service) setUserClaims(ctx context.Context, ident identity.Requester, identifier string, claims *auth.IDClaims) error { - id, err := strconv.ParseInt(identifier, 10, 64) - if err != nil { - return err - } - - if id == 0 { - return nil - } - - claims.Email = ident.GetEmail() - claims.EmailVerified = ident.IsEmailVerified() - - info, err := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: id}) - if err != nil { - // we ignore errors when a user don't have external user auth - if !errors.Is(err, user.ErrUserNotFound) { - s.logger.FromContext(ctx).Error("Failed to fetch auth info", "userId", id, "error", err) - } - - return nil - } - - claims.AuthenticatedBy = info.AuthModule - - return nil -} - func (s *Service) hook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { // FIXME(kalleep): we should probably lazy load this token, err := s.SignIdentity(ctx, identity) diff --git a/pkg/services/auth/idimpl/service_test.go b/pkg/services/auth/idimpl/service_test.go index a5e746904b5..37e89229460 100644 --- a/pkg/services/auth/idimpl/service_test.go +++ b/pkg/services/auth/idimpl/service_test.go @@ -16,8 +16,6 @@ import ( "github.com/grafana/grafana/pkg/services/authn/authntest" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/login/authinfotest" - "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -32,7 +30,7 @@ func Test_ProvideService(t *testing.T) { }, } - _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil, nil) + _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil) assert.True(t, hookRegistered) }) @@ -46,7 +44,7 @@ func Test_ProvideService(t *testing.T) { }, } - _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil, nil) + _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil) assert.False(t, hookRegistered) }) } @@ -69,7 +67,7 @@ func TestService_SignIdentity(t *testing.T) { s := ProvideService( setting.NewCfg(), signer, remotecache.NewFakeCacheStorage(), featuremgmt.WithFeatures(featuremgmt.FlagIdForwarding), - &authntest.FakeService{}, &authinfotest.FakeService{ExpectedError: user.ErrUserNotFound}, nil, + &authntest.FakeService{}, nil, ) token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"}) require.NoError(t, err) @@ -80,9 +78,9 @@ func TestService_SignIdentity(t *testing.T) { s := ProvideService( setting.NewCfg(), signer, remotecache.NewFakeCacheStorage(), featuremgmt.WithFeatures(featuremgmt.FlagIdForwarding), - &authntest.FakeService{}, &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{AuthModule: login.AzureADAuthModule}}, nil, + &authntest.FakeService{}, nil, ) - token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"}) + token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1", AuthenticatedBy: login.AzureADAuthModule}) require.NoError(t, err) parsed, err := jwt.ParseSigned(token) diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index 18e0b9ba32f..1132039126b 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -138,7 +138,7 @@ type RedirectClient interface { // that should happen during logout and supports client specific redirect URL. type LogoutClient interface { Client - Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*Redirect, bool) + Logout(ctx context.Context, user identity.Requester) (*Redirect, bool) } type PasswordClient interface { diff --git a/pkg/services/authn/authnimpl/registration.go b/pkg/services/authn/authnimpl/registration.go index f3dd550fb5d..2b9ebe5ffd1 100644 --- a/pkg/services/authn/authnimpl/registration.go +++ b/pkg/services/authn/authnimpl/registration.go @@ -43,7 +43,7 @@ func ProvideRegistration( authnSvc.RegisterClient(clients.ProvideAPIKey(apikeyService)) if cfg.LoginCookieName != "" { - authnSvc.RegisterClient(clients.ProvideSession(cfg, sessionService)) + authnSvc.RegisterClient(clients.ProvideSession(cfg, sessionService, authInfoService)) } var proxyClients []authn.ProxyClient diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index 7b97a480a08..0b77580947d 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -19,7 +19,6 @@ import ( "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/clients" - "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util/errutil" @@ -47,20 +46,18 @@ func ProvideIdentitySynchronizer(s *Service) authn.IdentitySynchronizer { func ProvideService( cfg *setting.Cfg, tracer tracing.Tracer, - sessionService auth.UserTokenService, usageStats usagestats.Service, - authInfoService login.AuthInfoService, registerer prometheus.Registerer, + sessionService auth.UserTokenService, usageStats usagestats.Service, registerer prometheus.Registerer, ) *Service { s := &Service{ - log: log.New("authn.service"), - cfg: cfg, - clients: make(map[string]authn.Client), - clientQueue: newQueue[authn.ContextAwareClient](), - tracer: tracer, - metrics: newMetrics(registerer), - authInfoService: authInfoService, - sessionService: sessionService, - postAuthHooks: newQueue[authn.PostAuthHookFn](), - postLoginHooks: newQueue[authn.PostLoginHookFn](), + log: log.New("authn.service"), + cfg: cfg, + clients: make(map[string]authn.Client), + clientQueue: newQueue[authn.ContextAwareClient](), + tracer: tracer, + metrics: newMetrics(registerer), + sessionService: sessionService, + postAuthHooks: newQueue[authn.PostAuthHookFn](), + postLoginHooks: newQueue[authn.PostLoginHookFn](), } usageStats.RegisterMetricsFunc(s.getUsageStats) @@ -77,8 +74,7 @@ type Service struct { tracer tracing.Tracer metrics *metrics - authInfoService login.AuthInfoService - sessionService auth.UserTokenService + sessionService auth.UserTokenService // postAuthHooks are called after a successful authentication. They can modify the identity. postAuthHooks *queue[authn.PostAuthHookFn] @@ -259,9 +255,8 @@ func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionTo return redirect, nil } - info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID}) - if info != nil { - client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_")) + if authModule := user.GetAuthenticatedBy(); authModule != "" { + client := authn.ClientWithPrefix(strings.TrimPrefix(authModule, "oauth_")) c, ok := s.clients[client] if !ok { @@ -275,7 +270,7 @@ func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionTo goto Default } - clientRedirect, ok := logoutClient.Logout(ctx, user, info) + clientRedirect, ok := logoutClient.Logout(ctx, user) if !ok { goto Default } diff --git a/pkg/services/authn/authnimpl/service_test.go b/pkg/services/authn/authnimpl/service_test.go index b053569d5ee..c65bc7d7cc9 100644 --- a/pkg/services/authn/authnimpl/service_test.go +++ b/pkg/services/authn/authnimpl/service_test.go @@ -19,8 +19,6 @@ import ( "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/authntest" - "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/login/authinfotest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -309,7 +307,6 @@ func TestService_Logout(t *testing.T) { identity *authn.Identity sessionToken *usertoken.UserToken - info *login.UserAuth client authn.Client @@ -332,27 +329,24 @@ func TestService_Logout(t *testing.T) { }, { desc: "should redirect to default redirect url when client is not found", - identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, - info: &login.UserAuth{AuthModule: "notFound"}, + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "notfound"}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, expectedTokenRevoked: true, }, { desc: "should redirect to default redirect url when client do not implement logout extension", - identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, - info: &login.UserAuth{AuthModule: "azuread"}, + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "azuread"}, expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"}, client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"}, expectedTokenRevoked: true, }, { desc: "should redirect to client specific url", - identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)}, - info: &login.UserAuth{AuthModule: "azuread"}, + identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), AuthenticatedBy: "azuread"}, expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"}, client: &authntest.MockClient{ NameFunc: func() string { return "auth.client.azuread" }, - LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) { + LogoutFunc: func(ctx context.Context, _ identity.Requester) (*authn.Redirect, bool) { return &authn.Redirect{URL: "http://idp.com/logout"}, true }, }, @@ -369,9 +363,6 @@ func TestService_Logout(t *testing.T) { svc.RegisterClient(tt.client) } svc.cfg.AppSubURL = "http://localhost:3000" - svc.authInfoService = &authinfotest.FakeService{ - ExpectedUserAuth: tt.info, - } svc.sessionService = &authtest.FakeUserAuthTokenService{ RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error { diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go index 8756ce603cb..935b4ec68c9 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go @@ -3,6 +3,7 @@ package sync import ( "context" "errors" + "strings" "time" "golang.org/x/sync/singleflight" @@ -39,11 +40,16 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn return nil } - // not authenticated through session tokens, so we can skip this hook + // Not authenticated through session tokens, so we can skip this hook. if identity.SessionToken == nil { return nil } + // Not authenticated with a oauth provider, so we can skip this hook. + if !strings.HasPrefix(identity.GetAuthenticatedBy(), "oauth") { + return nil + } + _, err, _ := s.singleflightGroup.Do(identity.ID, func() (interface{}, error) { s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID) diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go index 9bd0863c82c..03db9ff1dd0 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go @@ -51,7 +51,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { }, { desc: "should invalidate access token and session token if token refresh fails", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule}, expectHasEntryCalled: true, expectedTryRefreshErr: errors.New("some err"), expectTryRefreshTokenCalled: true, @@ -62,7 +62,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { }, { desc: "should refresh the token successfully", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule}, expectHasEntryCalled: false, expectTryRefreshTokenCalled: true, expectInvalidateOauthTokensCalled: false, @@ -70,7 +70,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { }, { desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)", - identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule}, expectHasEntryCalled: true, expectTryRefreshTokenCalled: true, expectInvalidateOauthTokensCalled: false, diff --git a/pkg/services/authn/authnimpl/sync/user_sync.go b/pkg/services/authn/authnimpl/sync/user_sync.go index 3c65d8096eb..23d0d6f9042 100644 --- a/pkg/services/authn/authnimpl/sync/user_sync.go +++ b/pkg/services/authn/authnimpl/sync/user_sync.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/grafana/grafana/pkg/infra/log" authidentity "github.com/grafana/grafana/pkg/services/auth/identity" @@ -110,12 +111,13 @@ func (s *UserSync) FetchSyncedUserHook(ctx context.Context, identity *authn.Iden if !identity.ClientParams.FetchSyncedUser { return nil } + namespace, id := identity.GetNamespacedID() - if namespace != authn.NamespaceUser && namespace != authn.NamespaceServiceAccount { + if !authidentity.IsNamespace(namespace, authn.NamespaceUser, authn.NamespaceServiceAccount) { return nil } - userID, err := authidentity.IntIdentifier(namespace, id) + userID, err := strconv.ParseInt(id, 10, 64) if err != nil { s.log.FromContext(ctx).Warn("got invalid identity ID", "id", id, "err", err) return nil diff --git a/pkg/services/authn/authntest/mock.go b/pkg/services/authn/authntest/mock.go index 193200e5498..2d8d13b93e5 100644 --- a/pkg/services/authn/authntest/mock.go +++ b/pkg/services/authn/authntest/mock.go @@ -6,7 +6,6 @@ import ( "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" - "github.com/grafana/grafana/pkg/services/login" ) var _ authn.Service = new(MockService) @@ -68,7 +67,7 @@ type MockClient struct { TestFunc func(ctx context.Context, r *authn.Request) bool PriorityFunc func() uint HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error - LogoutFunc func(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) + LogoutFunc func(ctx context.Context, user identity.Requester) (*authn.Redirect, bool) } func (m MockClient) Name() string { @@ -106,9 +105,9 @@ func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn return nil } -func (m *MockClient) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) { +func (m *MockClient) Logout(ctx context.Context, user identity.Requester) (*authn.Redirect, bool) { if m.LogoutFunc != nil { - return m.LogoutFunc(ctx, user, info) + return m.LogoutFunc(ctx, user) } return nil, false } diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index 9cfcec0e635..e9184f41e37 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -241,10 +241,21 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir }, nil } -func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) { +func (c *OAuth) Logout(ctx context.Context, user identity.Requester) (*authn.Redirect, bool) { token := c.oauthService.GetCurrentOAuthToken(ctx, user) - if err := c.oauthService.InvalidateOAuthTokens(ctx, info); err != nil { + namespace, id := user.GetNamespacedID() + userID, err := identity.UserIdentifier(namespace, id) + if err != nil { + c.log.FromContext(ctx).Error("Failed to parse user id", "namespace", namespace, "id", id, "error", err) + return nil, false + } + + if err := c.oauthService.InvalidateOAuthTokens(ctx, &login.UserAuth{ + UserId: userID, + AuthId: user.GetAuthID(), + AuthModule: user.GetAuthenticatedBy(), + }); err != nil { namespace, id := user.GetNamespacedID() c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err) } diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index 7432ac2ad83..a746c2d4633 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -486,7 +486,7 @@ func TestOAuth_Logout(t *testing.T) { } c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc, &setting.OSSImpl{Cfg: tt.cfg}, featuremgmt.WithFeatures()) - redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{}) + redirect, ok := c.Logout(context.Background(), &authn.Identity{}) assert.Equal(t, tt.expectedOK, ok) if tt.expectedOK { diff --git a/pkg/services/authn/clients/session.go b/pkg/services/authn/clients/session.go index 1af47b16449..014192c27d1 100644 --- a/pkg/services/authn/clients/session.go +++ b/pkg/services/authn/clients/session.go @@ -2,29 +2,34 @@ package clients import ( "context" + "errors" "net/url" "time" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) var _ authn.ContextAwareClient = new(Session) -func ProvideSession(cfg *setting.Cfg, sessionService auth.UserTokenService) *Session { +func ProvideSession(cfg *setting.Cfg, sessionService auth.UserTokenService, authInfoService login.AuthInfoService) *Session { return &Session{ - cfg: cfg, - sessionService: sessionService, - log: log.New(authn.ClientSession), + cfg: cfg, + log: log.New(authn.ClientSession), + sessionService: sessionService, + authInfoService: authInfoService, } } type Session struct { - cfg *setting.Cfg - sessionService auth.UserTokenService - log log.Logger + cfg *setting.Cfg + log log.Logger + sessionService auth.UserTokenService + authInfoService login.AuthInfoService } func (s *Session) Name() string { @@ -51,14 +56,26 @@ func (s *Session) Authenticate(ctx context.Context, r *authn.Request) (*authn.Id return nil, authn.ErrTokenNeedsRotation.Errorf("token needs to be rotated") } - return &authn.Identity{ + ident := &authn.Identity{ ID: authn.NamespacedID(authn.NamespaceUser, token.UserId), SessionToken: token, ClientParams: authn.ClientParams{ FetchSyncedUser: true, SyncPermissions: true, }, - }, nil + } + + info, err := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: token.UserId}) + if err != nil { + if !errors.Is(err, user.ErrUserNotFound) { + s.log.FromContext(ctx).Error("Failed to fetch auth info", "err", err) + } + return ident, nil + } + + ident.AuthID = info.AuthId + ident.AuthenticatedBy = info.AuthModule + return ident, nil } func (s *Session) Test(ctx context.Context, r *authn.Request) bool { diff --git a/pkg/services/authn/clients/session_test.go b/pkg/services/authn/clients/session_test.go index 41fc4e5f7ae..0315083ea5b 100644 --- a/pkg/services/authn/clients/session_test.go +++ b/pkg/services/authn/clients/session_test.go @@ -13,6 +13,9 @@ import ( "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/login/authinfotest" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -26,7 +29,7 @@ func TestSession_Test(t *testing.T) { cfg := setting.NewCfg() cfg.LoginCookieName = "" cfg.LoginMaxLifetime = 20 * time.Second - s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{}) + s := ProvideSession(cfg, &authtest.FakeUserAuthTokenService{}, &authinfotest.FakeService{}) disabled := s.Test(context.Background(), &authn.Request{HTTPRequest: validHTTPReq}) assert.False(t, disabled) @@ -60,7 +63,8 @@ func TestSession_Authenticate(t *testing.T) { } type fields struct { - sessionService auth.UserTokenService + authInfoService login.AuthInfoService + sessionService auth.UserTokenService } type args struct { r *authn.Request @@ -75,7 +79,8 @@ func TestSession_Authenticate(t *testing.T) { { name: "cookie not found", fields: fields{ - sessionService: &authtest.FakeUserAuthTokenService{}, + sessionService: &authtest.FakeUserAuthTokenService{}, + authInfoService: &authinfotest.FakeService{}, }, args: args{r: &authn.Request{HTTPRequest: &http.Request{}}}, wantID: nil, @@ -87,6 +92,7 @@ func TestSession_Authenticate(t *testing.T) { sessionService: &authtest.FakeUserAuthTokenService{LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { return validToken, nil }}, + authInfoService: &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{}}, }, args: args{r: &authn.Request{HTTPRequest: validHTTPReq}}, wantID: &authn.Identity{ @@ -108,6 +114,7 @@ func TestSession_Authenticate(t *testing.T) { RotatedAt: time.Now().Add(-11 * time.Minute).Unix(), }, nil }}, + authInfoService: &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{}}, }, args: args{r: &authn.Request{HTTPRequest: validHTTPReq}}, wantErr: true, @@ -118,6 +125,7 @@ func TestSession_Authenticate(t *testing.T) { sessionService: &authtest.FakeUserAuthTokenService{LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { return validToken, nil }}, + authInfoService: &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{}}, }, args: args{r: &authn.Request{HTTPRequest: validHTTPReq}}, wantID: &authn.Identity{ @@ -130,6 +138,48 @@ func TestSession_Authenticate(t *testing.T) { }, wantErr: false, }, + { + name: "should set authID and authenticated by for externally authenticated user", + fields: fields{ + sessionService: &authtest.FakeUserAuthTokenService{LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return validToken, nil + }}, + authInfoService: &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{AuthId: "1", AuthModule: "oauth_azuread"}}, + }, + args: args{r: &authn.Request{HTTPRequest: validHTTPReq}}, + wantID: &authn.Identity{ + ID: "user:1", + AuthID: "1", + AuthenticatedBy: "oauth_azuread", + SessionToken: validToken, + + ClientParams: authn.ClientParams{ + SyncPermissions: true, + FetchSyncedUser: true, + }, + }, + wantErr: false, + }, + { + name: "should not set authID and authenticated by when no auth info exists for user", + fields: fields{ + sessionService: &authtest.FakeUserAuthTokenService{LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return validToken, nil + }}, + authInfoService: &authinfotest.FakeService{ExpectedError: user.ErrUserNotFound}, + }, + args: args{r: &authn.Request{HTTPRequest: validHTTPReq}}, + wantID: &authn.Identity{ + ID: "user:1", + SessionToken: validToken, + + ClientParams: authn.ClientParams{ + SyncPermissions: true, + FetchSyncedUser: true, + }, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -137,7 +187,7 @@ func TestSession_Authenticate(t *testing.T) { cfg.LoginCookieName = cookieName cfg.TokenRotationIntervalMinutes = 10 cfg.LoginMaxLifetime = 20 * time.Second - s := ProvideSession(cfg, tt.fields.sessionService) + s := ProvideSession(cfg, tt.fields.sessionService, tt.fields.authInfoService) got, err := s.Authenticate(context.Background(), tt.args.r) require.True(t, (err != nil) == tt.wantErr, err) diff --git a/pkg/services/authn/identity.go b/pkg/services/authn/identity.go index 54da56849e6..d0d31f7286a 100644 --- a/pkg/services/authn/identity.go +++ b/pkg/services/authn/identity.go @@ -102,6 +102,10 @@ func (i *Identity) GetNamespacedID() (namespace string, identifier string) { return split[0], split[1] } +func (i *Identity) GetAuthID() string { + return i.AuthID +} + func (i *Identity) GetAuthenticatedBy() string { return i.AuthenticatedBy } @@ -228,6 +232,7 @@ func (i *Identity) SignedInUser() *user.SignedInUser { Login: i.Login, Name: i.Name, Email: i.Email, + AuthID: i.AuthID, AuthenticatedBy: i.AuthenticatedBy, IsGrafanaAdmin: i.GetIsGrafanaAdmin(), IsAnonymous: namespace == NamespaceAnonymous, diff --git a/pkg/services/user/identity.go b/pkg/services/user/identity.go index 3a25d7f9f18..3d9d2fab649 100644 --- a/pkg/services/user/identity.go +++ b/pkg/services/user/identity.go @@ -14,15 +14,18 @@ const ( ) type SignedInUser struct { - UserID int64 `xorm:"user_id"` - UserUID string `xorm:"user_uid"` - OrgID int64 `xorm:"org_id"` - OrgName string - OrgRole roletype.RoleType - Login string - Name string - Email string - EmailVerified bool + UserID int64 `xorm:"user_id"` + UserUID string `xorm:"user_uid"` + OrgID int64 `xorm:"org_id"` + OrgName string + OrgRole roletype.RoleType + Login string + Name string + Email string + EmailVerified bool + // AuthID will be set if user signed in using external method + AuthID string + // AuthenticatedBy be set if user signed in using external method AuthenticatedBy string ApiKeyID int64 `xorm:"api_key_id"` IsServiceAccount bool `xorm:"is_service_account"` @@ -222,6 +225,14 @@ func (u *SignedInUser) GetNamespacedID() (string, string) { return parts[0], parts[1] } +func (u *SignedInUser) GetAuthID() string { + return u.AuthID +} + +func (u *SignedInUser) GetAuthenticatedBy() string { + return u.AuthenticatedBy +} + func (u *SignedInUser) IsAuthenticatedBy(providers ...string) bool { for _, p := range providers { if u.AuthenticatedBy == p { @@ -252,11 +263,6 @@ func (u *SignedInUser) GetDisplayName() string { return u.NameOrFallback() } -// DEPRECATEAD: Returns the authentication method used -func (u *SignedInUser) GetAuthenticatedBy() string { - return u.AuthenticatedBy -} - func (u *SignedInUser) GetIDToken() string { return u.IDToken } diff --git a/pkg/tests/web/index_view_test.go b/pkg/tests/web/index_view_test.go index 9668b94fca7..6ea42b35135 100644 --- a/pkg/tests/web/index_view_test.go +++ b/pkg/tests/web/index_view_test.go @@ -1,6 +1,7 @@ package web import ( + "bytes" "context" "encoding/json" "fmt" @@ -40,7 +41,7 @@ func TestIntegrationIndexView(t *testing.T) { addr, _ := testinfra.StartGrafana(t, grafDir, cfgPath) // nolint:bodyclose - resp, html := makeRequest(t, addr, "", "") + resp, html := makeRequest(t, addr, nil) assert.Regexp(t, `script-src 'self' 'unsafe-eval' 'unsafe-inline' 'strict-dynamic' 'nonce-[^']+';object-src 'none';font-src 'self';style-src 'self' 'unsafe-inline' blob:;img-src \* data:;base-uri 'self';connect-src 'self' grafana.com ws://localhost:3000/ wss://localhost:3000/;manifest-src 'self';media-src 'none';form-action 'self';`, resp.Header.Get("Content-Security-Policy")) assert.Regexp(t, `