diff --git a/server/channels/app/platform/license.go b/server/channels/app/platform/license.go index d3a1b0ecb5..d71348d9e9 100644 --- a/server/channels/app/platform/license.go +++ b/server/channels/app/platform/license.go @@ -5,6 +5,7 @@ package platform import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -17,6 +18,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/jobs" + "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" "github.com/mattermost/mattermost/server/v8/channels/utils" "github.com/mattermost/mattermost/server/v8/einterfaces" ) @@ -95,7 +97,7 @@ func (ps *PlatformService) LoadLicense() { } } - record, nErr := ps.Store.License().Get(licenseId) + record, nErr := ps.Store.License().Get(sqlstore.WithMaster(context.Background()), licenseId) if nErr != nil { ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr)) ps.SetLicense(nil) @@ -166,7 +168,7 @@ func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *mo record.Id = license.Id record.Bytes = string(licenseBytes) - _, nErr := ps.Store.License().Save(record) + nErr := ps.Store.License().Save(record) if nErr != nil { ps.RemoveLicense() var appErr *model.AppError diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go index 82f479076d..a34bb74b24 100644 --- a/server/channels/store/opentracinglayer/opentracinglayer.go +++ b/server/channels/store/opentracinglayer/opentracinglayer.go @@ -5179,7 +5179,7 @@ func (s *OpenTracingLayerJobStore) UpdateStatusOptimistically(id string, current return result, err } -func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { +func (s *OpenTracingLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Get") s.Root.Store.SetContext(newCtx) @@ -5188,7 +5188,7 @@ func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, err }() defer span.Finish() - result, err := s.LicenseStore.Get(id) + result, err := s.LicenseStore.Get(ctx, id) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -5215,7 +5215,7 @@ func (s *OpenTracingLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) return result, err } -func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { +func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) error { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Save") s.Root.Store.SetContext(newCtx) @@ -5224,13 +5224,13 @@ func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*mode }() defer span.Finish() - result, err := s.LicenseStore.Save(license) + err := s.LicenseStore.Save(license) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) } - return result, err + return err } func (s *OpenTracingLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) { diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index c9519324ab..8bbdf544fb 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -5857,11 +5857,11 @@ func (s *RetryLayerJobStore) UpdateStatusOptimistically(id string, currentStatus } -func (s *RetryLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { +func (s *RetryLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) { tries := 0 for { - result, err := s.LicenseStore.Get(id) + result, err := s.LicenseStore.Get(ctx, id) if err == nil { return result, nil } @@ -5899,21 +5899,21 @@ func (s *RetryLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) { } -func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { +func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) error { tries := 0 for { - result, err := s.LicenseStore.Save(license) + err := s.LicenseStore.Save(license) if err == nil { - return result, nil + return nil } if !isRepeatableError(err) { - return result, err + return err } tries++ if tries >= 3 { err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return result, err + return err } timepkg.Sleep(100 * timepkg.Millisecond) } diff --git a/server/channels/store/sqlstore/license_store.go b/server/channels/store/sqlstore/license_store.go index 545518224b..d58d7c27e5 100644 --- a/server/channels/store/sqlstore/license_store.go +++ b/server/channels/store/sqlstore/license_store.go @@ -4,6 +4,8 @@ package sqlstore import ( + "context" + sq "github.com/mattermost/squirrel" "github.com/pkg/errors" @@ -23,46 +25,41 @@ func newSqlLicenseStore(sqlStore *SqlStore) store.LicenseStore { // Save validates and stores the license instance in the database. The Id // and Bytes fields are mandatory. The Bytes field is limited to a maximum -// of 10000 bytes. If the license ID matches an existing license in the -// database it returns the license stored in the database. If not, it saves the -// new database and returns the created license with the CreateAt field -// updated. -func (ls SqlLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { +// of 10000 bytes. Provided license is saved only if missing. +func (ls SqlLicenseStore) Save(license *model.LicenseRecord) error { license.PreSave() if err := license.IsValid(); err != nil { - return nil, err + return err } + query := ls.getQueryBuilder(). - Select("Id, CreateAt, Bytes"). - From("Licenses"). - Where(sq.Eq{"Id": license.Id}) + Insert("Licenses"). + Columns("Id", "CreateAt", "Bytes"). + Values(license.Id, license.CreateAt, license.Bytes) + + if ls.DriverName() == model.DatabaseDriverMysql { + query = query.SuffixExpr(sq.Expr("ON DUPLICATE KEY UPDATE Id=Id")) + } else { + query = query.SuffixExpr(sq.Expr("ON CONFLICT (Id) DO NOTHING")) + } + queryString, args, err := query.ToSql() if err != nil { - return nil, errors.Wrap(err, "license_tosql") + return errors.Wrap(err, "license_tosql") } - var storedLicense model.LicenseRecord - if err := ls.GetReplicaX().Get(&storedLicense, queryString, args...); err != nil { - // Only insert if not exists - query, args, err := ls.getQueryBuilder(). - Insert("Licenses"). - Columns("Id", "CreateAt", "Bytes"). - Values(license.Id, license.CreateAt, license.Bytes). - ToSql() - if err != nil { - return nil, errors.Wrap(err, "license_record_tosql") - } - if _, err := ls.GetMasterX().Exec(query, args...); err != nil { - return nil, errors.Wrapf(err, "failed to get License with licenseId=%s", license.Id) - } - return license, nil + + if _, err := ls.GetMasterX().Exec(queryString, args...); err != nil { + return errors.Wrapf(err, "failed to insert License with licenseId=%s", license.Id) } - return &storedLicense, nil + + return nil + } // Get obtains the license with the provided id parameter from the database. // If the license doesn't exist it returns a model.AppError with // http.StatusNotFound in the StatusCode field. -func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) { +func (ls SqlLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) { query := ls.getQueryBuilder(). Select("Id, CreateAt, Bytes"). From("Licenses"). @@ -74,7 +71,7 @@ func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) { } license := &model.LicenseRecord{} - if err := ls.GetReplicaX().Get(license, queryString, args...); err != nil { + if err := ls.DBXFromContext(ctx).Get(license, queryString, args...); err != nil { return nil, store.NewErrNotFound("License", id) } return license, nil diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 0d48be54f5..a8897d46ff 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -638,8 +638,8 @@ type PreferenceStore interface { } type LicenseStore interface { - Save(license *model.LicenseRecord) (*model.LicenseRecord, error) - Get(id string) (*model.LicenseRecord, error) + Save(license *model.LicenseRecord) error + Get(ctx context.Context, id string) (*model.LicenseRecord, error) GetAll() ([]*model.LicenseRecord, error) } diff --git a/server/channels/store/storetest/license_store.go b/server/channels/store/storetest/license_store.go index 8705a0e127..e167ab8634 100644 --- a/server/channels/store/storetest/license_store.go +++ b/server/channels/store/storetest/license_store.go @@ -4,6 +4,7 @@ package storetest import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -22,15 +23,15 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) { l1.Id = model.NewId() l1.Bytes = "junk" - _, err := ss.License().Save(&l1) + err := ss.License().Save(&l1) require.NoError(t, err, "couldn't save license record") - _, err = ss.License().Save(&l1) + err = ss.License().Save(&l1) require.NoError(t, err, "shouldn't fail on trying to save existing license record") l1.Id = "" - _, err = ss.License().Save(&l1) + err = ss.License().Save(&l1) require.Error(t, err, "should fail on invalid license") } @@ -39,14 +40,14 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) { l1.Id = model.NewId() l1.Bytes = "junk" - _, err := ss.License().Save(&l1) + err := ss.License().Save(&l1) require.NoError(t, err) - record, err := ss.License().Get(l1.Id) + record, err := ss.License().Get(context.Background(), l1.Id) require.NoError(t, err, "couldn't get license") require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match") - _, err = ss.License().Get("missing") + _, err = ss.License().Get(context.Background(), "missing") require.Error(t, err, "should fail on get license") } diff --git a/server/channels/store/storetest/mocks/LicenseStore.go b/server/channels/store/storetest/mocks/LicenseStore.go index 155c2de967..04c9ddd764 100644 --- a/server/channels/store/storetest/mocks/LicenseStore.go +++ b/server/channels/store/storetest/mocks/LicenseStore.go @@ -5,6 +5,8 @@ package mocks import ( + context "context" + model "github.com/mattermost/mattermost/server/public/model" mock "github.com/stretchr/testify/mock" ) @@ -14,25 +16,25 @@ type LicenseStore struct { mock.Mock } -// Get provides a mock function with given fields: id -func (_m *LicenseStore) Get(id string) (*model.LicenseRecord, error) { - ret := _m.Called(id) +// Get provides a mock function with given fields: ctx, id +func (_m *LicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) { + ret := _m.Called(ctx, id) var r0 *model.LicenseRecord var r1 error - if rf, ok := ret.Get(0).(func(string) (*model.LicenseRecord, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) (*model.LicenseRecord, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(string) *model.LicenseRecord); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) *model.LicenseRecord); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.LicenseRecord) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -67,29 +69,17 @@ func (_m *LicenseStore) GetAll() ([]*model.LicenseRecord, error) { } // Save provides a mock function with given fields: license -func (_m *LicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { +func (_m *LicenseStore) Save(license *model.LicenseRecord) error { ret := _m.Called(license) - var r0 *model.LicenseRecord - var r1 error - if rf, ok := ret.Get(0).(func(*model.LicenseRecord) (*model.LicenseRecord, error)); ok { - return rf(license) - } - if rf, ok := ret.Get(0).(func(*model.LicenseRecord) *model.LicenseRecord); ok { + var r0 error + if rf, ok := ret.Get(0).(func(*model.LicenseRecord) error); ok { r0 = rf(license) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*model.LicenseRecord) - } + r0 = ret.Error(0) } - if rf, ok := ret.Get(1).(func(*model.LicenseRecord) error); ok { - r1 = rf(license) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } type mockConstructorTestingTNewLicenseStore interface { diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index c56996b3ff..162db771fb 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -4705,10 +4705,10 @@ func (s *TimerLayerJobStore) UpdateStatusOptimistically(id string, currentStatus return result, err } -func (s *TimerLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { +func (s *TimerLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) { start := time.Now() - result, err := s.LicenseStore.Get(id) + result, err := s.LicenseStore.Get(ctx, id) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -4737,10 +4737,10 @@ func (s *TimerLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) { return result, err } -func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { +func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) error { start := time.Now() - result, err := s.LicenseStore.Save(license) + err := s.LicenseStore.Save(license) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -4750,7 +4750,7 @@ func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.Lice } s.Root.Metrics.ObserveStoreMethodDuration("LicenseStore.Save", success, elapsed) } - return result, err + return err } func (s *TimerLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {