mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
[MM-54456] Fix potential read after write issue when loading license (#24524)
* Fix potential read after write issue when loading license * Use upsert
This commit is contained in:
parent
c53b5f7b2b
commit
b4a47803e6
@ -5,6 +5,7 @@ package platform
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -17,6 +18,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/jobs"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
||||
"github.com/mattermost/mattermost/server/v8/einterfaces"
|
||||
)
|
||||
@ -95,7 +97,7 @@ func (ps *PlatformService) LoadLicense() {
|
||||
}
|
||||
}
|
||||
|
||||
record, nErr := ps.Store.License().Get(licenseId)
|
||||
record, nErr := ps.Store.License().Get(sqlstore.WithMaster(context.Background()), licenseId)
|
||||
if nErr != nil {
|
||||
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
|
||||
ps.SetLicense(nil)
|
||||
@ -166,7 +168,7 @@ func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *mo
|
||||
record.Id = license.Id
|
||||
record.Bytes = string(licenseBytes)
|
||||
|
||||
_, nErr := ps.Store.License().Save(record)
|
||||
nErr := ps.Store.License().Save(record)
|
||||
if nErr != nil {
|
||||
ps.RemoveLicense()
|
||||
var appErr *model.AppError
|
||||
|
@ -5179,7 +5179,7 @@ func (s *OpenTracingLayerJobStore) UpdateStatusOptimistically(id string, current
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
func (s *OpenTracingLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Get")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
@ -5188,7 +5188,7 @@ func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, err
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, err := s.LicenseStore.Get(id)
|
||||
result, err := s.LicenseStore.Get(ctx, id)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
@ -5215,7 +5215,7 @@ func (s *OpenTracingLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
|
||||
func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) error {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Save")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
@ -5224,13 +5224,13 @@ func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*mode
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, err := s.LicenseStore.Save(license)
|
||||
err := s.LicenseStore.Save(license)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
|
||||
return result, err
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {
|
||||
|
@ -5857,11 +5857,11 @@ func (s *RetryLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
func (s *RetryLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, err := s.LicenseStore.Get(id)
|
||||
result, err := s.LicenseStore.Get(ctx, id)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
@ -5899,21 +5899,21 @@ func (s *RetryLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
|
||||
func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) error {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, err := s.LicenseStore.Save(license)
|
||||
err := s.LicenseStore.Save(license)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
return nil
|
||||
}
|
||||
if !isRepeatableError(err) {
|
||||
return result, err
|
||||
return err
|
||||
}
|
||||
tries++
|
||||
if tries >= 3 {
|
||||
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
|
||||
return result, err
|
||||
return err
|
||||
}
|
||||
timepkg.Sleep(100 * timepkg.Millisecond)
|
||||
}
|
||||
|
@ -4,6 +4,8 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
sq "github.com/mattermost/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@ -23,46 +25,41 @@ func newSqlLicenseStore(sqlStore *SqlStore) store.LicenseStore {
|
||||
|
||||
// Save validates and stores the license instance in the database. The Id
|
||||
// and Bytes fields are mandatory. The Bytes field is limited to a maximum
|
||||
// of 10000 bytes. If the license ID matches an existing license in the
|
||||
// database it returns the license stored in the database. If not, it saves the
|
||||
// new database and returns the created license with the CreateAt field
|
||||
// updated.
|
||||
func (ls SqlLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
|
||||
// of 10000 bytes. Provided license is saved only if missing.
|
||||
func (ls SqlLicenseStore) Save(license *model.LicenseRecord) error {
|
||||
license.PreSave()
|
||||
if err := license.IsValid(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
query := ls.getQueryBuilder().
|
||||
Select("Id, CreateAt, Bytes").
|
||||
From("Licenses").
|
||||
Where(sq.Eq{"Id": license.Id})
|
||||
Insert("Licenses").
|
||||
Columns("Id", "CreateAt", "Bytes").
|
||||
Values(license.Id, license.CreateAt, license.Bytes)
|
||||
|
||||
if ls.DriverName() == model.DatabaseDriverMysql {
|
||||
query = query.SuffixExpr(sq.Expr("ON DUPLICATE KEY UPDATE Id=Id"))
|
||||
} else {
|
||||
query = query.SuffixExpr(sq.Expr("ON CONFLICT (Id) DO NOTHING"))
|
||||
}
|
||||
|
||||
queryString, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "license_tosql")
|
||||
return errors.Wrap(err, "license_tosql")
|
||||
}
|
||||
var storedLicense model.LicenseRecord
|
||||
if err := ls.GetReplicaX().Get(&storedLicense, queryString, args...); err != nil {
|
||||
// Only insert if not exists
|
||||
query, args, err := ls.getQueryBuilder().
|
||||
Insert("Licenses").
|
||||
Columns("Id", "CreateAt", "Bytes").
|
||||
Values(license.Id, license.CreateAt, license.Bytes).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "license_record_tosql")
|
||||
}
|
||||
if _, err := ls.GetMasterX().Exec(query, args...); err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to get License with licenseId=%s", license.Id)
|
||||
}
|
||||
return license, nil
|
||||
|
||||
if _, err := ls.GetMasterX().Exec(queryString, args...); err != nil {
|
||||
return errors.Wrapf(err, "failed to insert License with licenseId=%s", license.Id)
|
||||
}
|
||||
return &storedLicense, nil
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Get obtains the license with the provided id parameter from the database.
|
||||
// If the license doesn't exist it returns a model.AppError with
|
||||
// http.StatusNotFound in the StatusCode field.
|
||||
func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
func (ls SqlLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
|
||||
query := ls.getQueryBuilder().
|
||||
Select("Id, CreateAt, Bytes").
|
||||
From("Licenses").
|
||||
@ -74,7 +71,7 @@ func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
}
|
||||
|
||||
license := &model.LicenseRecord{}
|
||||
if err := ls.GetReplicaX().Get(license, queryString, args...); err != nil {
|
||||
if err := ls.DBXFromContext(ctx).Get(license, queryString, args...); err != nil {
|
||||
return nil, store.NewErrNotFound("License", id)
|
||||
}
|
||||
return license, nil
|
||||
|
@ -638,8 +638,8 @@ type PreferenceStore interface {
|
||||
}
|
||||
|
||||
type LicenseStore interface {
|
||||
Save(license *model.LicenseRecord) (*model.LicenseRecord, error)
|
||||
Get(id string) (*model.LicenseRecord, error)
|
||||
Save(license *model.LicenseRecord) error
|
||||
Get(ctx context.Context, id string) (*model.LicenseRecord, error)
|
||||
GetAll() ([]*model.LicenseRecord, error)
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
package storetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -22,15 +23,15 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) {
|
||||
l1.Id = model.NewId()
|
||||
l1.Bytes = "junk"
|
||||
|
||||
_, err := ss.License().Save(&l1)
|
||||
err := ss.License().Save(&l1)
|
||||
require.NoError(t, err, "couldn't save license record")
|
||||
|
||||
_, err = ss.License().Save(&l1)
|
||||
err = ss.License().Save(&l1)
|
||||
require.NoError(t, err, "shouldn't fail on trying to save existing license record")
|
||||
|
||||
l1.Id = ""
|
||||
|
||||
_, err = ss.License().Save(&l1)
|
||||
err = ss.License().Save(&l1)
|
||||
require.Error(t, err, "should fail on invalid license")
|
||||
}
|
||||
|
||||
@ -39,14 +40,14 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) {
|
||||
l1.Id = model.NewId()
|
||||
l1.Bytes = "junk"
|
||||
|
||||
_, err := ss.License().Save(&l1)
|
||||
err := ss.License().Save(&l1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := ss.License().Get(l1.Id)
|
||||
record, err := ss.License().Get(context.Background(), l1.Id)
|
||||
require.NoError(t, err, "couldn't get license")
|
||||
|
||||
require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match")
|
||||
|
||||
_, err = ss.License().Get("missing")
|
||||
_, err = ss.License().Get(context.Background(), "missing")
|
||||
require.Error(t, err, "should fail on get license")
|
||||
}
|
||||
|
@ -5,6 +5,8 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
model "github.com/mattermost/mattermost/server/public/model"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
@ -14,25 +16,25 @@ type LicenseStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Get provides a mock function with given fields: id
|
||||
func (_m *LicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
ret := _m.Called(id)
|
||||
// Get provides a mock function with given fields: ctx, id
|
||||
func (_m *LicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
var r0 *model.LicenseRecord
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) (*model.LicenseRecord, error)); ok {
|
||||
return rf(id)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*model.LicenseRecord, error)); ok {
|
||||
return rf(ctx, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string) *model.LicenseRecord); ok {
|
||||
r0 = rf(id)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *model.LicenseRecord); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.LicenseRecord)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(id)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@ -67,29 +69,17 @@ func (_m *LicenseStore) GetAll() ([]*model.LicenseRecord, error) {
|
||||
}
|
||||
|
||||
// Save provides a mock function with given fields: license
|
||||
func (_m *LicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
|
||||
func (_m *LicenseStore) Save(license *model.LicenseRecord) error {
|
||||
ret := _m.Called(license)
|
||||
|
||||
var r0 *model.LicenseRecord
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) (*model.LicenseRecord, error)); ok {
|
||||
return rf(license)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) *model.LicenseRecord); ok {
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) error); ok {
|
||||
r0 = rf(license)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.LicenseRecord)
|
||||
}
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(*model.LicenseRecord) error); ok {
|
||||
r1 = rf(license)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
return r0
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewLicenseStore interface {
|
||||
|
@ -4705,10 +4705,10 @@ func (s *TimerLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
|
||||
func (s *TimerLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
|
||||
start := time.Now()
|
||||
|
||||
result, err := s.LicenseStore.Get(id)
|
||||
result, err := s.LicenseStore.Get(ctx, id)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
@ -4737,10 +4737,10 @@ func (s *TimerLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
|
||||
func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) error {
|
||||
start := time.Now()
|
||||
|
||||
result, err := s.LicenseStore.Save(license)
|
||||
err := s.LicenseStore.Save(license)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
@ -4750,7 +4750,7 @@ func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.Lice
|
||||
}
|
||||
s.Root.Metrics.ObserveStoreMethodDuration("LicenseStore.Save", success, elapsed)
|
||||
}
|
||||
return result, err
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *TimerLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user