MM-53879: Fix recursive loading of license (#24200)

There were multiple problems with loading of a license.

1. It was called from inside app/server.go and app/platform/service.go. The first one wasn't really needed anymore, so we remove it.

2. To make loading of a license work across a cluster, the license load action was attached along with the `InvalidateAllCachesSkipSend` method. But the problem with that was that it would even get called in the caller node as well, putting it in a recursive loop.

```
LoadLicense -> SaveLicense -> InvalidateAllCaches -> InvalidateAllCachesSkipSend -> LoadLicense
```

To fix this, we create a dedicated loadLicense cluster event and move it away from the `InvalidateAllCachesSkipSend` method. And then from the caller side, we just trigger this action.

3. We also remove the first call to check license expiration which would load the license again. This is unnecessary because if the license is expired, server wouldn't start at all.

While here, we also make some other improvements like removing unnecessary goroutine spawning while publishing websocket events. They are already handled asynchronously, so there is no need
to create a goroutine for that.

We also remove

```
ps.ReloadConfig()
ps.InvalidateAllCaches()
```

from requestTrialLicense as they are already called from inside `*PlatformService.SaveLicense`.

And lastly, we remove the `*model.AppError` return from `*PlatformService.InvalidateAllCaches` because there was nothing to return at all.

https://mattermost.atlassian.net/browse/MM-53879

```release-note
Fix several issues with loading of a license
```
This commit is contained in:
Agniva De Sarker 2023-08-21 20:17:16 +05:30 committed by GitHub
parent 95b76e42ad
commit dd73c2af0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 31 additions and 165 deletions

View File

@ -317,11 +317,7 @@ func invalidateCaches(c *Context, w http.ResponseWriter, r *http.Request) {
return return
} }
appErr := c.App.Srv().InvalidateAllCaches() c.App.Srv().InvalidateAllCaches()
if appErr != nil {
c.Err = appErr
return
}
auditRec.Success() auditRec.Success()

View File

@ -3902,8 +3902,7 @@ func TestLoginWithLag(t *testing.T) {
_, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) _, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password)
require.NoError(t, err) require.NoError(t, err)
appErr = th.App.Srv().InvalidateAllCaches() th.App.Srv().InvalidateAllCaches()
require.Nil(t, appErr)
session, appErr := th.App.GetSession(th.Client.AuthToken) session, appErr := th.App.GetSession(th.Client.AuthToken)
require.Nil(t, appErr) require.Nil(t, appErr)

View File

@ -137,8 +137,8 @@ func (a *App) GetClusterStatus() []*model.ClusterInfo {
return infos return infos
} }
func (s *Server) InvalidateAllCaches() *model.AppError { func (s *Server) InvalidateAllCaches() {
return s.platform.InvalidateAllCaches() s.platform.InvalidateAllCaches()
} }
func (s *Server) InvalidateAllCachesSkipSend() { func (s *Server) InvalidateAllCachesSkipSend() {

View File

@ -131,10 +131,6 @@ func (s *Server) License() *model.License {
return s.platform.License() return s.platform.License()
} }
func (s *Server) LoadLicense() {
s.platform.LoadLicense()
}
func (s *Server) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { func (s *Server) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) {
return s.platform.SaveLicense(licenseBytes) return s.platform.SaveLicense(licenseBytes)
} }

View File

@ -1,109 +0,0 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package app
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
)
func TestLoadLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
th.App.Srv().LoadLicense()
require.Nil(t, th.App.Srv().License(), "shouldn't have a valid license")
}
func TestSaveLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
b1 := []byte("junk")
_, err := th.App.Srv().SaveLicense(b1)
require.NotNil(t, err, "shouldn't have saved license")
}
func TestRemoveLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
err := th.App.Srv().RemoveLicense()
require.Nil(t, err, "should have removed license")
}
func TestSetLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
l1 := &model.License{}
l1.Features = &model.Features{}
l1.Customer = &model.Customer{}
l1.StartsAt = model.GetMillis() - 1000
l1.ExpiresAt = model.GetMillis() + 100000
ok := th.App.Srv().SetLicense(l1)
require.True(t, ok, "license should have worked")
l3 := &model.License{}
l3.Features = &model.Features{}
l3.Customer = &model.Customer{}
l3.StartsAt = model.GetMillis() + 10000
l3.ExpiresAt = model.GetMillis() + 100000
ok = th.App.Srv().SetLicense(l3)
require.True(t, ok, "license should have passed")
}
func TestGetSanitizedClientLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
setLicense(th, nil)
m := th.App.Srv().GetSanitizedClientLicense()
_, ok := m["Name"]
assert.False(t, ok)
_, ok = m["SkuName"]
assert.False(t, ok)
}
func TestGenerateRenewalToken(t *testing.T) {
th := Setup(t)
defer th.TearDown()
t.Run("renewal token generated correctly", func(t *testing.T) {
setLicense(th, nil)
token, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration)
require.Nil(t, appErr)
require.NotEmpty(t, token)
})
t.Run("return error if there is no active license", func(t *testing.T) {
th.App.Srv().SetLicense(nil)
_, appErr := th.App.Srv().GenerateRenewalToken(JWTDefaultTokenExpiration)
require.NotNil(t, appErr)
})
}
func setLicense(th *TestHelper, customer *model.Customer) {
l1 := &model.License{}
l1.Features = &model.Features{}
if customer != nil {
l1.Customer = customer
} else {
l1.Customer = &model.Customer{}
l1.Customer.Name = "TestName"
l1.Customer.Email = "test@example.com"
}
l1.SkuName = "SKU NAME"
l1.SkuShortName = "SKU SHORT NAME"
l1.StartsAt = model.GetMillis() - 1000
l1.ExpiresAt = model.GetMillis() + 100000
th.App.Srv().SetLicense(l1)
}

View File

@ -98,8 +98,7 @@ func TestSendNotifications(t *testing.T) {
_, appErr = th.App.UpdateActive(th.Context, th.BasicUser2, false) _, appErr = th.App.UpdateActive(th.Context, th.BasicUser2, false)
require.Nil(t, appErr) require.Nil(t, appErr)
appErr = th.App.Srv().InvalidateAllCaches() th.App.Srv().InvalidateAllCaches()
require.Nil(t, appErr)
post3, appErr := th.App.CreatePostMissingChannel(th.Context, &model.Post{ post3, appErr := th.App.CreatePostMissingChannel(th.Context, &model.Post{
UserId: th.BasicUser.Id, UserId: th.BasicUser.Id,

View File

@ -113,9 +113,7 @@ func (a *App) DeleteOAuthApp(appID string) *model.AppError {
return model.NewAppError("DeleteOAuthApp", "app.oauth.delete_app.app_error", nil, "", http.StatusInternalServerError).Wrap(err) return model.NewAppError("DeleteOAuthApp", "app.oauth.delete_app.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
} }
if err := a.Srv().InvalidateAllCaches(); err != nil { a.Srv().InvalidateAllCaches()
mlog.Warn("error in invalidating cache", mlog.Err(err))
}
return nil return nil
} }

View File

@ -17,6 +17,7 @@ func (ps *PlatformService) RegisterClusterHandlers() {
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventPublish, ps.ClusterPublishHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventPublish, ps.ClusterPublishHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventUpdateStatus, ps.ClusterUpdateStatusHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventUpdateStatus, ps.ClusterUpdateStatusHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateAllCaches, ps.ClusterInvalidateAllCachesHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateAllCaches, ps.ClusterInvalidateAllCachesHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventLoadLicense, ps.LoadLicenseClusterHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelMembersNotifyProps, ps.clusterInvalidateCacheForChannelMembersNotifyPropHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelMembersNotifyProps, ps.clusterInvalidateCacheForChannelMembersNotifyPropHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelByName, ps.clusterInvalidateCacheForChannelByNameHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForChannelByName, ps.clusterInvalidateCacheForChannelByNameHandler)
ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForUser, ps.clusterInvalidateCacheForUserHandler) ps.clusterIFace.RegisterClusterMessageHandler(model.ClusterEventInvalidateCacheForUser, ps.clusterInvalidateCacheForUserHandler)
@ -154,10 +155,27 @@ func (ps *PlatformService) InvalidateAllCachesSkipSend() {
ps.Store.Webhook().ClearCaches() ps.Store.Webhook().ClearCaches()
linkCache.Purge() linkCache.Purge()
ps.LoadLicense()
} }
func (ps *PlatformService) InvalidateAllCaches() *model.AppError { func (ps *PlatformService) LoadLicenseClusterHandler(_ *model.ClusterMessage) {
ps.loadLicense()
}
func (ps *PlatformService) TriggerLoadLicense() {
ps.loadLicense()
if ps.clusterIFace != nil {
msg := &model.ClusterMessage{
Event: model.ClusterEventLoadLicense,
SendType: model.ClusterSendReliable,
WaitForAllToSend: true,
}
ps.clusterIFace.SendClusterMessage(msg)
}
}
func (ps *PlatformService) InvalidateAllCaches() {
ps.InvalidateAllCachesSkipSend() ps.InvalidateAllCachesSkipSend()
if ps.clusterIFace != nil { if ps.clusterIFace != nil {
@ -170,6 +188,4 @@ func (ps *PlatformService) InvalidateAllCaches() *model.AppError {
ps.clusterIFace.SendClusterMessage(msg) ps.clusterIFace.SendClusterMessage(msg)
} }
return nil
} }

View File

@ -46,7 +46,7 @@ func (ps *PlatformService) License() *model.License {
return ps.licenseValue.Load() return ps.licenseValue.Load()
} }
func (ps *PlatformService) LoadLicense() { func (ps *PlatformService) loadLicense() {
// ENV var overrides all other sources of license. // ENV var overrides all other sources of license.
licenseStr := os.Getenv(LicenseEnv) licenseStr := os.Getenv(LicenseEnv)
if licenseStr != "" { if licenseStr != "" {
@ -326,9 +326,6 @@ func (ps *PlatformService) RequestTrialLicense(trialRequest *model.TrialLicenseR
return err return err
} }
ps.ReloadConfig()
ps.InvalidateAllCaches()
return nil return nil
} }

View File

@ -12,14 +12,6 @@ import (
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
) )
func TestLoadLicense(t *testing.T) {
th := Setup(t)
defer th.TearDown()
th.Service.LoadLicense()
require.Nil(t, th.Service.License(), "shouldn't have a valid license")
}
func TestSaveLicense(t *testing.T) { func TestSaveLicense(t *testing.T) {
th := Setup(t) th := Setup(t)
defer th.TearDown() defer th.TearDown()

View File

@ -292,7 +292,7 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) {
// Step 7: Init License // Step 7: Init License
if model.BuildEnterpriseReady == "true" { if model.BuildEnterpriseReady == "true" {
ps.LoadLicense() ps.TriggerLoadLicense()
} }
// Step 8: Init Metrics Server depends on step 6 (store) and 7 (license) // Step 8: Init Metrics Server depends on step 6 (store) and 7 (license)
@ -353,9 +353,7 @@ func (ps *PlatformService) Start() error {
message := model.NewWebSocketEvent(model.WebsocketEventConfigChanged, "", "", "", nil, "") message := model.NewWebSocketEvent(model.WebsocketEventConfigChanged, "", "", "", nil, "")
message.Add("config", ps.ClientConfigWithComputed()) message.Add("config", ps.ClientConfigWithComputed())
ps.Go(func() { ps.Publish(message)
ps.Publish(message)
})
if err := ps.ReconfigureLogger(); err != nil { if err := ps.ReconfigureLogger(); err != nil {
mlog.Error("Error re-configuring logging after config change", mlog.Err(err)) mlog.Error("Error re-configuring logging after config change", mlog.Err(err))
@ -368,9 +366,7 @@ func (ps *PlatformService) Start() error {
message := model.NewWebSocketEvent(model.WebsocketEventLicenseChanged, "", "", "", nil, "") message := model.NewWebSocketEvent(model.WebsocketEventLicenseChanged, "", "", "", nil, "")
message.Add("license", ps.GetSanitizedClientLicense()) message.Add("license", ps.GetSanitizedClientLicense())
ps.Go(func() { ps.Publish(message)
ps.Publish(message)
})
}) })
return nil return nil

View File

@ -208,7 +208,6 @@ func NewServer(options ...Option) (*Server, error) {
// Depends on step 1 (s.Platform must be non-nil) // Depends on step 1 (s.Platform must be non-nil)
s.initEnterprise() s.initEnterprise()
// Needed to run before loading license.
s.userService, err = users.New(users.ServiceConfig{ s.userService, err = users.New(users.ServiceConfig{
UserStore: s.Store().User(), UserStore: s.Store().User(),
SessionStore: s.Store().Session(), SessionStore: s.Store().Session(),
@ -222,11 +221,6 @@ func NewServer(options ...Option) (*Server, error) {
return nil, errors.Wrapf(err, "unable to create users service") return nil, errors.Wrapf(err, "unable to create users service")
} }
if model.BuildEnterpriseReady == "true" {
// Dependent on user service
s.LoadLicense()
}
s.licenseWrapper = &licenseWrapper{ s.licenseWrapper = &licenseWrapper{
srv: s, srv: s,
} }
@ -1383,8 +1377,6 @@ func (s *Server) sendLicenseUpForRenewalEmail(users map[string]*model.User, lice
} }
func (s *Server) doLicenseExpirationCheck() { func (s *Server) doLicenseExpirationCheck() {
s.LoadLicense()
// This takes care of a rare edge case reported here https://mattermost.atlassian.net/browse/MM-40962 // This takes care of a rare edge case reported here https://mattermost.atlassian.net/browse/MM-40962
// To reproduce that case locally, attach a license to a server that was started with enterprise enabled // To reproduce that case locally, attach a license to a server that was started with enterprise enabled
// Then restart using BUILD_ENTERPRISE=false make restart-server to enter Team Edition // Then restart using BUILD_ENTERPRISE=false make restart-server to enter Team Edition
@ -1394,7 +1386,6 @@ func (s *Server) doLicenseExpirationCheck() {
} }
license := s.License() license := s.License()
if license == nil { if license == nil {
mlog.Debug("License cannot be found.") mlog.Debug("License cannot be found.")
return return

View File

@ -49,10 +49,6 @@ func initDBCommandContext(configDSN string, readOnlyConfigStore bool, options ..
a := app.New(app.ServerConnector(s.Channels())) a := app.New(app.ServerConnector(s.Channels()))
if model.BuildEnterpriseReady == "true" {
a.Srv().LoadLicense()
}
return a, nil return a, nil
} }

View File

@ -41,8 +41,6 @@ func jobserverCmdF(command *cobra.Command, args []string) error {
} }
defer a.Srv().Shutdown() defer a.Srv().Shutdown()
a.Srv().LoadLicense()
// Run jobs // Run jobs
mlog.Info("Starting Mattermost job server") mlog.Info("Starting Mattermost job server")
defer mlog.Info("Stopped Mattermost job server") defer mlog.Info("Stopped Mattermost job server")

View File

@ -9,6 +9,7 @@ const (
ClusterEventPublish ClusterEvent = "publish" ClusterEventPublish ClusterEvent = "publish"
ClusterEventUpdateStatus ClusterEvent = "update_status" ClusterEventUpdateStatus ClusterEvent = "update_status"
ClusterEventInvalidateAllCaches ClusterEvent = "inv_all_caches" ClusterEventInvalidateAllCaches ClusterEvent = "inv_all_caches"
ClusterEventLoadLicense ClusterEvent = "load_license"
ClusterEventInvalidateCacheForReactions ClusterEvent = "inv_reactions" ClusterEventInvalidateCacheForReactions ClusterEvent = "inv_reactions"
ClusterEventInvalidateCacheForChannelMembersNotifyProps ClusterEvent = "inv_channel_members_notify_props" ClusterEventInvalidateCacheForChannelMembersNotifyProps ClusterEvent = "inv_channel_members_notify_props"
ClusterEventInvalidateCacheForChannelByName ClusterEvent = "inv_channel_name" ClusterEventInvalidateCacheForChannelByName ClusterEvent = "inv_channel_name"