From dd73c2af0f5ee8a5a2cfa201bcf44e8b08a124c1 Mon Sep 17 00:00:00 2001 From: Agniva De Sarker Date: Mon, 21 Aug 2023 20:17:16 +0530 Subject: [PATCH] MM-53879: Fix recursive loading of license (#24200) There were multiple problems with loading of a license. 1. It was called from inside app/server.go and app/platform/service.go. The first one wasn't really needed anymore, so we remove it. 2. To make loading of a license work across a cluster, the license load action was attached along with the `InvalidateAllCachesSkipSend` method. But the problem with that was that it would even get called in the caller node as well, putting it in a recursive loop. ``` LoadLicense -> SaveLicense -> InvalidateAllCaches -> InvalidateAllCachesSkipSend -> LoadLicense ``` To fix this, we create a dedicated loadLicense cluster event and move it away from the `InvalidateAllCachesSkipSend` method. And then from the caller side, we just trigger this action. 3. We also remove the first call to check license expiration which would load the license again. This is unnecessary because if the license is expired, server wouldn't start at all. While here, we also make some other improvements like removing unnecessary goroutine spawning while publishing websocket events. They are already handled asynchronously, so there is no need to create a goroutine for that. We also remove ``` ps.ReloadConfig() ps.InvalidateAllCaches() ``` from requestTrialLicense as they are already called from inside `*PlatformService.SaveLicense`. And lastly, we remove the `*model.AppError` return from `*PlatformService.InvalidateAllCaches` because there was nothing to return at all. https://mattermost.atlassian.net/browse/MM-53879 ```release-note Fix several issues with loading of a license ``` --- server/channels/api4/system.go | 6 +- server/channels/api4/user_test.go | 3 +- server/channels/app/admin.go | 4 +- server/channels/app/license.go | 4 - server/channels/app/license_test.go | 109 ------------------ server/channels/app/notification_test.go | 3 +- server/channels/app/oauth.go | 4 +- .../channels/app/platform/cluster_handlers.go | 24 +++- server/channels/app/platform/license.go | 5 +- server/channels/app/platform/license_test.go | 8 -- server/channels/app/platform/service.go | 10 +- server/channels/app/server.go | 9 -- server/cmd/mattermost/commands/init.go | 4 - server/cmd/mattermost/commands/jobserver.go | 2 - server/public/model/cluster_message.go | 1 + 15 files changed, 31 insertions(+), 165 deletions(-) delete mode 100644 server/channels/app/license_test.go diff --git a/server/channels/api4/system.go b/server/channels/api4/system.go index 273818888c..9c8ad6fc65 100644 --- a/server/channels/api4/system.go +++ b/server/channels/api4/system.go @@ -317,11 +317,7 @@ func invalidateCaches(c *Context, w http.ResponseWriter, r *http.Request) { return } - appErr := c.App.Srv().InvalidateAllCaches() - if appErr != nil { - c.Err = appErr - return - } + c.App.Srv().InvalidateAllCaches() auditRec.Success() diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index aa10a94355..df2f19b850 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -3902,8 +3902,7 @@ func TestLoginWithLag(t *testing.T) { _, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) require.NoError(t, err) - appErr = th.App.Srv().InvalidateAllCaches() - require.Nil(t, appErr) + th.App.Srv().InvalidateAllCaches() session, appErr := th.App.GetSession(th.Client.AuthToken) require.Nil(t, appErr) diff --git a/server/channels/app/admin.go b/server/channels/app/admin.go index 7e1678cafc..1f970a6c9b 100644 --- a/server/channels/app/admin.go +++ b/server/channels/app/admin.go @@ -137,8 +137,8 @@ func (a *App) GetClusterStatus() []*model.ClusterInfo { return infos } -func (s *Server) InvalidateAllCaches() *model.AppError { - return s.platform.InvalidateAllCaches() +func (s *Server) InvalidateAllCaches() { + s.platform.InvalidateAllCaches() } func (s *Server) InvalidateAllCachesSkipSend() { diff --git a/server/channels/app/license.go b/server/channels/app/license.go index c2f879e475..33efbdd5e3 100644 --- a/server/channels/app/license.go +++ b/server/channels/app/license.go @@ -131,10 +131,6 @@ func (s *Server) License() *model.License { return s.platform.License() } -func (s *Server) LoadLicense() { - s.platform.LoadLicense() -} - func (s *Server) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { return s.platform.SaveLicense(licenseBytes) } diff --git a/server/channels/app/license_test.go b/server/channels/app/license_test.go deleted file mode 100644 index 26b4b58b70..0000000000 --- a/server/channels/app/license_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. -// See LICENSE.txt for license information. - -package app - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/mattermost/mattermost/server/public/model" -) - -func TestLoadLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - th.App.Srv().LoadLicense() - require.Nil(t, th.App.Srv().License(), "shouldn't have a valid license") -} - -func TestSaveLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - b1 := []byte("junk") - - _, err := th.App.Srv().SaveLicense(b1) - require.NotNil(t, err, "shouldn't have saved license") -} - -func TestRemoveLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - err := th.App.Srv().RemoveLicense() - require.Nil(t, err, "should have removed license") -} - -func TestSetLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - l1 := &model.License{} - l1.Features = &model.Features{} - l1.Customer = &model.Customer{} - l1.StartsAt = model.GetMillis() - 1000 - l1.ExpiresAt = model.GetMillis() + 100000 - ok := th.App.Srv().SetLicense(l1) - require.True(t, ok, "license should have worked") - - l3 := &model.License{} - l3.Features = &model.Features{} - l3.Customer = &model.Customer{} - l3.StartsAt = model.GetMillis() + 10000 - l3.ExpiresAt = model.GetMillis() + 100000 - ok = th.App.Srv().SetLicense(l3) - require.True(t, ok, "license should have passed") -} - -func TestGetSanitizedClientLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - setLicense(th, nil) - - m := th.App.Srv().GetSanitizedClientLicense() - - _, ok := m["Name"] - assert.False(t, ok) - _, ok = m["SkuName"] - assert.False(t, ok) -} - -func TestGenerateRenewalToken(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - t.Run("renewal token generated correctly", func(t *testing.T) { - setLicense(th, nil) - token, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration) - require.Nil(t, appErr) - require.NotEmpty(t, token) - }) - - t.Run("return error if there is no active license", func(t *testing.T) { - th.App.Srv().SetLicense(nil) - _, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration) - require.NotNil(t, appErr) - }) -} - -func setLicense(th *TestHelper, customer *model.Customer) { - l1 := &model.License{} - l1.Features = &model.Features{} - if customer != nil { - l1.Customer = customer - } else { - l1.Customer = &model.Customer{} - l1.Customer.Name = "TestName" - l1.Customer.Email = "test@example.com" - } - l1.SkuName = "SKU NAME" - l1.SkuShortName = "SKU SHORT NAME" - l1.StartsAt = model.GetMillis() - 1000 - l1.ExpiresAt = model.GetMillis() + 100000 - th.App.Srv().SetLicense(l1) -} diff --git a/server/channels/app/notification_test.go b/server/channels/app/notification_test.go index 9c4b6efc1e..621fa7703f 100644 --- a/server/channels/app/notification_test.go +++ b/server/channels/app/notification_test.go @@ -98,8 +98,7 @@ func TestSendNotifications(t *testing.T) { _, appErr = th.App.UpdateActive(th.Context, th.BasicUser2, false) require.Nil(t, appErr) - appErr = th.App.Srv().InvalidateAllCaches() - require.Nil(t, appErr) + th.App.Srv().InvalidateAllCaches() post3, appErr := th.App.CreatePostMissingChannel(th.Context, &model.Post{ UserId: th.BasicUser.Id, diff --git a/server/channels/app/oauth.go b/server/channels/app/oauth.go index 2a186b2fc7..1c82da9560 100644 --- a/server/channels/app/oauth.go +++ b/server/channels/app/oauth.go @@ -113,9 +113,7 @@ func (a *App) DeleteOAuthApp(appID string) *model.AppError { return model.NewAppError("DeleteOAuthApp", "app.oauth.delete_app.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } - if err := a.Srv().InvalidateAllCaches(); err != nil { - mlog.Warn("error in invalidating cache", mlog.Err(err)) - } + a.Srv().InvalidateAllCaches() return nil } diff --git a/server/channels/app/platform/cluster_handlers.go b/server/channels/app/platform/cluster_handlers.go index a45809085b..2d50fd1516 100644 --- a/server/channels/app/platform/cluster_handlers.go +++ b/server/channels/app/platform/cluster_handlers.go @@ -17,6 +17,7 @@ func (ps *PlatformService) RegisterClusterHandlers() { ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventPublish, ps.ClusterPublishHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventUpdateStatus, ps.ClusterUpdateStatusHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateAllCaches, ps.ClusterInvalidateAllCachesHandler) + ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventLoadLicense, ps.LoadLicenseClusterHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelMembersNotifyProps, ps.clusterInvalidateCacheForChannelMembersNotifyPropHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelByName, ps.clusterInvalidateCacheForChannelByNameHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForUser, ps.clusterInvalidateCacheForUserHandler) @@ -154,10 +155,27 @@ func (ps *PlatformService) InvalidateAllCachesSkipSend() { ps.Store.Webhook().ClearCaches() linkCache.Purge() - ps.LoadLicense() } -func (ps *PlatformService) InvalidateAllCaches() *model.AppError { +func (ps *PlatformService) LoadLicenseClusterHandler(_ *model.ClusterMessage) { + ps.loadLicense() +} + +func (ps *PlatformService) TriggerLoadLicense() { + ps.loadLicense() + + if ps.clusterIFace != nil { + msg := &model.ClusterMessage{ + Event: model.ClusterEventLoadLicense, + SendType: model.ClusterSendReliable, + WaitForAllToSend: true, + } + + ps.clusterIFace.SendClusterMessage(msg) + } +} + +func (ps *PlatformService) InvalidateAllCaches() { ps.InvalidateAllCachesSkipSend() if ps.clusterIFace != nil { @@ -170,6 +188,4 @@ func (ps *PlatformService) InvalidateAllCaches() *model.AppError { ps.clusterIFace.SendClusterMessage(msg) } - - return nil } diff --git a/server/channels/app/platform/license.go b/server/channels/app/platform/license.go index d3a1b0ecb5..dbe7afb27f 100644 --- a/server/channels/app/platform/license.go +++ b/server/channels/app/platform/license.go @@ -46,7 +46,7 @@ func (ps *PlatformService) License() *model.License { return ps.licenseValue.Load() } -func (ps *PlatformService) LoadLicense() { +func (ps *PlatformService) loadLicense() { // ENV var overrides all other sources of license. licenseStr := os.Getenv(LicenseEnv) if licenseStr != "" { @@ -326,9 +326,6 @@ func (ps *PlatformService) RequestTrialLicense(trialRequest *model.TrialLicenseR return err } - ps.ReloadConfig() - ps.InvalidateAllCaches() - return nil } diff --git a/server/channels/app/platform/license_test.go b/server/channels/app/platform/license_test.go index 3f1af98ad1..cbc15b507a 100644 --- a/server/channels/app/platform/license_test.go +++ b/server/channels/app/platform/license_test.go @@ -12,14 +12,6 @@ import ( "github.com/mattermost/mattermost/server/public/model" ) -func TestLoadLicense(t *testing.T) { - th := Setup(t) - defer th.TearDown() - - th.Service.LoadLicense() - require.Nil(t, th.Service.License(), "shouldn't have a valid license") -} - func TestSaveLicense(t *testing.T) { th := Setup(t) defer th.TearDown() diff --git a/server/channels/app/platform/service.go b/server/channels/app/platform/service.go index 8869f5ea40..9074e5258c 100644 --- a/server/channels/app/platform/service.go +++ b/server/channels/app/platform/service.go @@ -292,7 +292,7 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) { // Step 7: Init License if model.BuildEnterpriseReady == "true" { - ps.LoadLicense() + ps.TriggerLoadLicense() } // Step 8: Init Metrics Server depends on step 6 (store) and 7 (license) @@ -353,9 +353,7 @@ func (ps *PlatformService) Start() error { message := model.NewWebSocketEvent(model.WebsocketEventConfigChanged, "", "", "", nil, "") message.Add("config", ps.ClientConfigWithComputed()) - ps.Go(func() { - ps.Publish(message) - }) + ps.Publish(message) if err := ps.ReconfigureLogger(); err != nil { mlog.Error("Error re-configuring logging after config change", mlog.Err(err)) @@ -368,9 +366,7 @@ func (ps *PlatformService) Start() error { message := model.NewWebSocketEvent(model.WebsocketEventLicenseChanged, "", "", "", nil, "") message.Add("license", ps.GetSanitizedClientLicense()) - ps.Go(func() { - ps.Publish(message) - }) + ps.Publish(message) }) return nil diff --git a/server/channels/app/server.go b/server/channels/app/server.go index a45d7cc8b7..10e040d147 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -208,7 +208,6 @@ func NewServer(options ...Option) (*Server, error) { // Depends on step 1 (s.Platform must be non-nil) s.initEnterprise() - // Needed to run before loading license. s.userService, err = users.New(users.ServiceConfig{ UserStore: s.Store().User(), SessionStore: s.Store().Session(), @@ -222,11 +221,6 @@ func NewServer(options ...Option) (*Server, error) { return nil, errors.Wrapf(err, "unable to create users service") } - if model.BuildEnterpriseReady == "true" { - // Dependent on user service - s.LoadLicense() - } - s.licenseWrapper = &licenseWrapper{ srv: s, } @@ -1383,8 +1377,6 @@ func (s *Server) sendLicenseUpForRenewalEmail(users map[string]*model.User, lice } func (s *Server) doLicenseExpirationCheck() { - s.LoadLicense() - // This takes care of a rare edge case reported here https://mattermost.atlassian.net/browse/MM-40962 // To reproduce that case locally, attach a license to a server that was started with enterprise enabled // Then restart using BUILD_ENTERPRISE=false make restart-server to enter Team Edition @@ -1394,7 +1386,6 @@ func (s *Server) doLicenseExpirationCheck() { } license := s.License() - if license == nil { mlog.Debug("License cannot be found.") return diff --git a/server/cmd/mattermost/commands/init.go b/server/cmd/mattermost/commands/init.go index fc1e07b6a1..140624eb0c 100644 --- a/server/cmd/mattermost/commands/init.go +++ b/server/cmd/mattermost/commands/init.go @@ -49,10 +49,6 @@ func initDBCommandContext(configDSN string, readOnlyConfigStore bool, options .. a := app.New(app.ServerConnector(s.Channels())) - if model.BuildEnterpriseReady == "true" { - a.Srv().LoadLicense() - } - return a, nil } diff --git a/server/cmd/mattermost/commands/jobserver.go b/server/cmd/mattermost/commands/jobserver.go index 18a544b465..32895a1492 100644 --- a/server/cmd/mattermost/commands/jobserver.go +++ b/server/cmd/mattermost/commands/jobserver.go @@ -41,8 +41,6 @@ func jobserverCmdF(command *cobra.Command, args []string) error { } defer a.Srv().Shutdown() - a.Srv().LoadLicense() - // Run jobs mlog.Info("Starting Mattermost job server") defer mlog.Info("Stopped Mattermost job server") diff --git a/server/public/model/cluster_message.go b/server/public/model/cluster_message.go index 6ff912f9ac..19f5e4079e 100644 --- a/server/public/model/cluster_message.go +++ b/server/public/model/cluster_message.go @@ -9,6 +9,7 @@ const ( ClusterEventPublish ClusterEvent = "publish" ClusterEventUpdateStatus ClusterEvent = "update_status" ClusterEventInvalidateAllCaches ClusterEvent = "inv_all_caches" + ClusterEventLoadLicense ClusterEvent = "load_license" ClusterEventInvalidateCacheForReactions ClusterEvent = "inv_reactions" ClusterEventInvalidateCacheForChannelMembersNotifyProps ClusterEvent = "inv_channel_members_notify_props" ClusterEventInvalidateCacheForChannelByName ClusterEvent = "inv_channel_name"