[MM-54456] Fix potential read after write issue when loading license (#24524)

* Fix potential read after write issue when loading license

* Use upsert
This commit is contained in:
Claudio Costa 2023-09-13 14:47:12 -06:00 committed by GitHub
parent c53b5f7b2b
commit b4a47803e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 71 additions and 81 deletions

View File

@ -5,6 +5,7 @@ package platform
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
@ -17,6 +18,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/jobs"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
"github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/einterfaces"
)
@ -95,7 +97,7 @@ func (ps *PlatformService) LoadLicense() {
}
}
record, nErr := ps.Store.License().Get(licenseId)
record, nErr := ps.Store.License().Get(sqlstore.WithMaster(context.Background()), licenseId)
if nErr != nil {
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.SetLicense(nil)
@ -166,7 +168,7 @@ func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *mo
record.Id = license.Id
record.Bytes = string(licenseBytes)
_, nErr := ps.Store.License().Save(record)
nErr := ps.Store.License().Save(record)
if nErr != nil {
ps.RemoveLicense()
var appErr *model.AppError

View File

@ -5179,7 +5179,7 @@ func (s *OpenTracingLayerJobStore) UpdateStatusOptimistically(id string, current
return result, err
}
func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
func (s *OpenTracingLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Get")
s.Root.Store.SetContext(newCtx)
@ -5188,7 +5188,7 @@ func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, err
}()
defer span.Finish()
result, err := s.LicenseStore.Get(id)
result, err := s.LicenseStore.Get(ctx, id)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
@ -5215,7 +5215,7 @@ func (s *OpenTracingLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error)
return result, err
}
func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) error {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Save")
s.Root.Store.SetContext(newCtx)
@ -5224,13 +5224,13 @@ func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*mode
}()
defer span.Finish()
result, err := s.LicenseStore.Save(license)
err := s.LicenseStore.Save(license)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
return result, err
return err
}
func (s *OpenTracingLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {

View File

@ -5857,11 +5857,11 @@ func (s *RetryLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
}
func (s *RetryLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
func (s *RetryLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
tries := 0
for {
result, err := s.LicenseStore.Get(id)
result, err := s.LicenseStore.Get(ctx, id)
if err == nil {
return result, nil
}
@ -5899,21 +5899,21 @@ func (s *RetryLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
}
func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) error {
tries := 0
for {
result, err := s.LicenseStore.Save(license)
err := s.LicenseStore.Save(license)
if err == nil {
return result, nil
return nil
}
if !isRepeatableError(err) {
return result, err
return err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
return result, err
return err
}
timepkg.Sleep(100 * timepkg.Millisecond)
}

View File

@ -4,6 +4,8 @@
package sqlstore
import (
"context"
sq "github.com/mattermost/squirrel"
"github.com/pkg/errors"
@ -23,46 +25,41 @@ func newSqlLicenseStore(sqlStore *SqlStore) store.LicenseStore {
// Save validates and stores the license instance in the database. The Id
// and Bytes fields are mandatory. The Bytes field is limited to a maximum
// of 10000 bytes. If the license ID matches an existing license in the
// database it returns the license stored in the database. If not, it saves the
// new database and returns the created license with the CreateAt field
// updated.
func (ls SqlLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
// of 10000 bytes. Provided license is saved only if missing.
func (ls SqlLicenseStore) Save(license *model.LicenseRecord) error {
license.PreSave()
if err := license.IsValid(); err != nil {
return nil, err
return err
}
query := ls.getQueryBuilder().
Select("Id, CreateAt, Bytes").
From("Licenses").
Where(sq.Eq{"Id": license.Id})
Insert("Licenses").
Columns("Id", "CreateAt", "Bytes").
Values(license.Id, license.CreateAt, license.Bytes)
if ls.DriverName() == model.DatabaseDriverMysql {
query = query.SuffixExpr(sq.Expr("ON DUPLICATE KEY UPDATE Id=Id"))
} else {
query = query.SuffixExpr(sq.Expr("ON CONFLICT (Id) DO NOTHING"))
}
queryString, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "license_tosql")
return errors.Wrap(err, "license_tosql")
}
var storedLicense model.LicenseRecord
if err := ls.GetReplicaX().Get(&storedLicense, queryString, args...); err != nil {
// Only insert if not exists
query, args, err := ls.getQueryBuilder().
Insert("Licenses").
Columns("Id", "CreateAt", "Bytes").
Values(license.Id, license.CreateAt, license.Bytes).
ToSql()
if err != nil {
return nil, errors.Wrap(err, "license_record_tosql")
}
if _, err := ls.GetMasterX().Exec(query, args...); err != nil {
return nil, errors.Wrapf(err, "failed to get License with licenseId=%s", license.Id)
}
return license, nil
if _, err := ls.GetMasterX().Exec(queryString, args...); err != nil {
return errors.Wrapf(err, "failed to insert License with licenseId=%s", license.Id)
}
return &storedLicense, nil
return nil
}
// Get obtains the license with the provided id parameter from the database.
// If the license doesn't exist it returns a model.AppError with
// http.StatusNotFound in the StatusCode field.
func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) {
func (ls SqlLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
query := ls.getQueryBuilder().
Select("Id, CreateAt, Bytes").
From("Licenses").
@ -74,7 +71,7 @@ func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) {
}
license := &model.LicenseRecord{}
if err := ls.GetReplicaX().Get(license, queryString, args...); err != nil {
if err := ls.DBXFromContext(ctx).Get(license, queryString, args...); err != nil {
return nil, store.NewErrNotFound("License", id)
}
return license, nil

View File

@ -638,8 +638,8 @@ type PreferenceStore interface {
}
type LicenseStore interface {
Save(license *model.LicenseRecord) (*model.LicenseRecord, error)
Get(id string) (*model.LicenseRecord, error)
Save(license *model.LicenseRecord) error
Get(ctx context.Context, id string) (*model.LicenseRecord, error)
GetAll() ([]*model.LicenseRecord, error)
}

View File

@ -4,6 +4,7 @@
package storetest
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -22,15 +23,15 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) {
l1.Id = model.NewId()
l1.Bytes = "junk"
_, err := ss.License().Save(&l1)
err := ss.License().Save(&l1)
require.NoError(t, err, "couldn't save license record")
_, err = ss.License().Save(&l1)
err = ss.License().Save(&l1)
require.NoError(t, err, "shouldn't fail on trying to save existing license record")
l1.Id = ""
_, err = ss.License().Save(&l1)
err = ss.License().Save(&l1)
require.Error(t, err, "should fail on invalid license")
}
@ -39,14 +40,14 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) {
l1.Id = model.NewId()
l1.Bytes = "junk"
_, err := ss.License().Save(&l1)
err := ss.License().Save(&l1)
require.NoError(t, err)
record, err := ss.License().Get(l1.Id)
record, err := ss.License().Get(context.Background(), l1.Id)
require.NoError(t, err, "couldn't get license")
require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match")
_, err = ss.License().Get("missing")
_, err = ss.License().Get(context.Background(), "missing")
require.Error(t, err, "should fail on get license")
}

View File

@ -5,6 +5,8 @@
package mocks
import (
context "context"
model "github.com/mattermost/mattermost/server/public/model"
mock "github.com/stretchr/testify/mock"
)
@ -14,25 +16,25 @@ type LicenseStore struct {
mock.Mock
}
// Get provides a mock function with given fields: id
func (_m *LicenseStore) Get(id string) (*model.LicenseRecord, error) {
ret := _m.Called(id)
// Get provides a mock function with given fields: ctx, id
func (_m *LicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
ret := _m.Called(ctx, id)
var r0 *model.LicenseRecord
var r1 error
if rf, ok := ret.Get(0).(func(string) (*model.LicenseRecord, error)); ok {
return rf(id)
if rf, ok := ret.Get(0).(func(context.Context, string) (*model.LicenseRecord, error)); ok {
return rf(ctx, id)
}
if rf, ok := ret.Get(0).(func(string) *model.LicenseRecord); ok {
r0 = rf(id)
if rf, ok := ret.Get(0).(func(context.Context, string) *model.LicenseRecord); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.LicenseRecord)
}
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(id)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@ -67,29 +69,17 @@ func (_m *LicenseStore) GetAll() ([]*model.LicenseRecord, error) {
}
// Save provides a mock function with given fields: license
func (_m *LicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
func (_m *LicenseStore) Save(license *model.LicenseRecord) error {
ret := _m.Called(license)
var r0 *model.LicenseRecord
var r1 error
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) (*model.LicenseRecord, error)); ok {
return rf(license)
}
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) *model.LicenseRecord); ok {
var r0 error
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) error); ok {
r0 = rf(license)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.LicenseRecord)
}
r0 = ret.Error(0)
}
if rf, ok := ret.Get(1).(func(*model.LicenseRecord) error); ok {
r1 = rf(license)
} else {
r1 = ret.Error(1)
}
return r0, r1
return r0
}
type mockConstructorTestingTNewLicenseStore interface {

View File

@ -4705,10 +4705,10 @@ func (s *TimerLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
return result, err
}
func (s *TimerLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) {
func (s *TimerLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
start := time.Now()
result, err := s.LicenseStore.Get(id)
result, err := s.LicenseStore.Get(ctx, id)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@ -4737,10 +4737,10 @@ func (s *TimerLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
return result, err
}
func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) error {
start := time.Now()
result, err := s.LicenseStore.Save(license)
err := s.LicenseStore.Save(license)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@ -4750,7 +4750,7 @@ func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.Lice
}
s.Root.Metrics.ObserveStoreMethodDuration("LicenseStore.Save", success, elapsed)
}
return result, err
return err
}
func (s *TimerLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {