[MM-54456] Fix potential read after write issue when loading license (#24524)

* Fix potential read after write issue when loading license

* Use upsert
This commit is contained in:
Claudio Costa 2023-09-13 14:47:12 -06:00 committed by GitHub
parent c53b5f7b2b
commit b4a47803e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 71 additions and 81 deletions

View File

@ -5,6 +5,7 @@ package platform
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -17,6 +18,7 @@ import (
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/jobs" "github.com/mattermost/mattermost/server/v8/channels/jobs"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
"github.com/mattermost/mattermost/server/v8/channels/utils" "github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/einterfaces" "github.com/mattermost/mattermost/server/v8/einterfaces"
) )
@ -95,7 +97,7 @@ func (ps *PlatformService) LoadLicense() {
} }
} }
record, nErr := ps.Store.License().Get(licenseId) record, nErr := ps.Store.License().Get(sqlstore.WithMaster(context.Background()), licenseId)
if nErr != nil { if nErr != nil {
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr)) ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.SetLicense(nil) ps.SetLicense(nil)
@ -166,7 +168,7 @@ func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *mo
record.Id = license.Id record.Id = license.Id
record.Bytes = string(licenseBytes) record.Bytes = string(licenseBytes)
_, nErr := ps.Store.License().Save(record) nErr := ps.Store.License().Save(record)
if nErr != nil { if nErr != nil {
ps.RemoveLicense() ps.RemoveLicense()
var appErr *model.AppError var appErr *model.AppError

View File

@ -5179,7 +5179,7 @@ func (s *OpenTracingLayerJobStore) UpdateStatusOptimistically(id string, current
return result, err return result, err
} }
func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { func (s *OpenTracingLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
origCtx := s.Root.Store.Context() origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Get") span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Get")
s.Root.Store.SetContext(newCtx) s.Root.Store.SetContext(newCtx)
@ -5188,7 +5188,7 @@ func (s *OpenTracingLayerLicenseStore) Get(id string) (*model.LicenseRecord, err
}() }()
defer span.Finish() defer span.Finish()
result, err := s.LicenseStore.Get(id) result, err := s.LicenseStore.Get(ctx, id)
if err != nil { if err != nil {
span.LogFields(spanlog.Error(err)) span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true) ext.Error.Set(span, true)
@ -5215,7 +5215,7 @@ func (s *OpenTracingLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error)
return result, err return result, err
} }
func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) error {
origCtx := s.Root.Store.Context() origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Save") span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "LicenseStore.Save")
s.Root.Store.SetContext(newCtx) s.Root.Store.SetContext(newCtx)
@ -5224,13 +5224,13 @@ func (s *OpenTracingLayerLicenseStore) Save(license *model.LicenseRecord) (*mode
}() }()
defer span.Finish() defer span.Finish()
result, err := s.LicenseStore.Save(license) err := s.LicenseStore.Save(license)
if err != nil { if err != nil {
span.LogFields(spanlog.Error(err)) span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true) ext.Error.Set(span, true)
} }
return result, err return err
} }
func (s *OpenTracingLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) { func (s *OpenTracingLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {

View File

@ -5857,11 +5857,11 @@ func (s *RetryLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
} }
func (s *RetryLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { func (s *RetryLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
tries := 0 tries := 0
for { for {
result, err := s.LicenseStore.Get(id) result, err := s.LicenseStore.Get(ctx, id)
if err == nil { if err == nil {
return result, nil return result, nil
} }
@ -5899,21 +5899,21 @@ func (s *RetryLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
} }
func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { func (s *RetryLayerLicenseStore) Save(license *model.LicenseRecord) error {
tries := 0 tries := 0
for { for {
result, err := s.LicenseStore.Save(license) err := s.LicenseStore.Save(license)
if err == nil { if err == nil {
return result, nil return nil
} }
if !isRepeatableError(err) { if !isRepeatableError(err) {
return result, err return err
} }
tries++ tries++
if tries >= 3 { if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
return result, err return err
} }
timepkg.Sleep(100 * timepkg.Millisecond) timepkg.Sleep(100 * timepkg.Millisecond)
} }

View File

@ -4,6 +4,8 @@
package sqlstore package sqlstore
import ( import (
"context"
sq "github.com/mattermost/squirrel" sq "github.com/mattermost/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -23,46 +25,41 @@ func newSqlLicenseStore(sqlStore *SqlStore) store.LicenseStore {
// Save validates and stores the license instance in the database. The Id // Save validates and stores the license instance in the database. The Id
// and Bytes fields are mandatory. The Bytes field is limited to a maximum // and Bytes fields are mandatory. The Bytes field is limited to a maximum
// of 10000 bytes. If the license ID matches an existing license in the // of 10000 bytes. Provided license is saved only if missing.
// database it returns the license stored in the database. If not, it saves the func (ls SqlLicenseStore) Save(license *model.LicenseRecord) error {
// new database and returns the created license with the CreateAt field
// updated.
func (ls SqlLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) {
license.PreSave() license.PreSave()
if err := license.IsValid(); err != nil { if err := license.IsValid(); err != nil {
return nil, err return err
} }
query := ls.getQueryBuilder(). query := ls.getQueryBuilder().
Select("Id, CreateAt, Bytes"). Insert("Licenses").
From("Licenses"). Columns("Id", "CreateAt", "Bytes").
Where(sq.Eq{"Id": license.Id}) Values(license.Id, license.CreateAt, license.Bytes)
if ls.DriverName() == model.DatabaseDriverMysql {
query = query.SuffixExpr(sq.Expr("ON DUPLICATE KEY UPDATE Id=Id"))
} else {
query = query.SuffixExpr(sq.Expr("ON CONFLICT (Id) DO NOTHING"))
}
queryString, args, err := query.ToSql() queryString, args, err := query.ToSql()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "license_tosql") return errors.Wrap(err, "license_tosql")
} }
var storedLicense model.LicenseRecord
if err := ls.GetReplicaX().Get(&storedLicense, queryString, args...); err != nil { if _, err := ls.GetMasterX().Exec(queryString, args...); err != nil {
// Only insert if not exists return errors.Wrapf(err, "failed to insert License with licenseId=%s", license.Id)
query, args, err := ls.getQueryBuilder().
Insert("Licenses").
Columns("Id", "CreateAt", "Bytes").
Values(license.Id, license.CreateAt, license.Bytes).
ToSql()
if err != nil {
return nil, errors.Wrap(err, "license_record_tosql")
}
if _, err := ls.GetMasterX().Exec(query, args...); err != nil {
return nil, errors.Wrapf(err, "failed to get License with licenseId=%s", license.Id)
}
return license, nil
} }
return &storedLicense, nil
return nil
} }
// Get obtains the license with the provided id parameter from the database. // Get obtains the license with the provided id parameter from the database.
// If the license doesn't exist it returns a model.AppError with // If the license doesn't exist it returns a model.AppError with
// http.StatusNotFound in the StatusCode field. // http.StatusNotFound in the StatusCode field.
func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) { func (ls SqlLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
query := ls.getQueryBuilder(). query := ls.getQueryBuilder().
Select("Id, CreateAt, Bytes"). Select("Id, CreateAt, Bytes").
From("Licenses"). From("Licenses").
@ -74,7 +71,7 @@ func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, error) {
} }
license := &model.LicenseRecord{} license := &model.LicenseRecord{}
if err := ls.GetReplicaX().Get(license, queryString, args...); err != nil { if err := ls.DBXFromContext(ctx).Get(license, queryString, args...); err != nil {
return nil, store.NewErrNotFound("License", id) return nil, store.NewErrNotFound("License", id)
} }
return license, nil return license, nil

View File

@ -638,8 +638,8 @@ type PreferenceStore interface {
} }
type LicenseStore interface { type LicenseStore interface {
Save(license *model.LicenseRecord) (*model.LicenseRecord, error) Save(license *model.LicenseRecord) error
Get(id string) (*model.LicenseRecord, error) Get(ctx context.Context, id string) (*model.LicenseRecord, error)
GetAll() ([]*model.LicenseRecord, error) GetAll() ([]*model.LicenseRecord, error)
} }

View File

@ -4,6 +4,7 @@
package storetest package storetest
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -22,15 +23,15 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) {
l1.Id = model.NewId() l1.Id = model.NewId()
l1.Bytes = "junk" l1.Bytes = "junk"
_, err := ss.License().Save(&l1) err := ss.License().Save(&l1)
require.NoError(t, err, "couldn't save license record") require.NoError(t, err, "couldn't save license record")
_, err = ss.License().Save(&l1) err = ss.License().Save(&l1)
require.NoError(t, err, "shouldn't fail on trying to save existing license record") require.NoError(t, err, "shouldn't fail on trying to save existing license record")
l1.Id = "" l1.Id = ""
_, err = ss.License().Save(&l1) err = ss.License().Save(&l1)
require.Error(t, err, "should fail on invalid license") require.Error(t, err, "should fail on invalid license")
} }
@ -39,14 +40,14 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) {
l1.Id = model.NewId() l1.Id = model.NewId()
l1.Bytes = "junk" l1.Bytes = "junk"
_, err := ss.License().Save(&l1) err := ss.License().Save(&l1)
require.NoError(t, err) require.NoError(t, err)
record, err := ss.License().Get(l1.Id) record, err := ss.License().Get(context.Background(), l1.Id)
require.NoError(t, err, "couldn't get license") require.NoError(t, err, "couldn't get license")
require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match") require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match")
_, err = ss.License().Get("missing") _, err = ss.License().Get(context.Background(), "missing")
require.Error(t, err, "should fail on get license") require.Error(t, err, "should fail on get license")
} }

View File

@ -5,6 +5,8 @@
package mocks package mocks
import ( import (
context "context"
model "github.com/mattermost/mattermost/server/public/model" model "github.com/mattermost/mattermost/server/public/model"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
@ -14,25 +16,25 @@ type LicenseStore struct {
mock.Mock mock.Mock
} }
// Get provides a mock function with given fields: id // Get provides a mock function with given fields: ctx, id
func (_m *LicenseStore) Get(id string) (*model.LicenseRecord, error) { func (_m *LicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
ret := _m.Called(id) ret := _m.Called(ctx, id)
var r0 *model.LicenseRecord var r0 *model.LicenseRecord
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(string) (*model.LicenseRecord, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, string) (*model.LicenseRecord, error)); ok {
return rf(id) return rf(ctx, id)
} }
if rf, ok := ret.Get(0).(func(string) *model.LicenseRecord); ok { if rf, ok := ret.Get(0).(func(context.Context, string) *model.LicenseRecord); ok {
r0 = rf(id) r0 = rf(ctx, id)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.LicenseRecord) r0 = ret.Get(0).(*model.LicenseRecord)
} }
} }
if rf, ok := ret.Get(1).(func(string) error); ok { if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(id) r1 = rf(ctx, id)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -67,29 +69,17 @@ func (_m *LicenseStore) GetAll() ([]*model.LicenseRecord, error) {
} }
// Save provides a mock function with given fields: license // Save provides a mock function with given fields: license
func (_m *LicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { func (_m *LicenseStore) Save(license *model.LicenseRecord) error {
ret := _m.Called(license) ret := _m.Called(license)
var r0 *model.LicenseRecord var r0 error
var r1 error if rf, ok := ret.Get(0).(func(*model.LicenseRecord) error); ok {
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) (*model.LicenseRecord, error)); ok {
return rf(license)
}
if rf, ok := ret.Get(0).(func(*model.LicenseRecord) *model.LicenseRecord); ok {
r0 = rf(license) r0 = rf(license)
} else { } else {
if ret.Get(0) != nil { r0 = ret.Error(0)
r0 = ret.Get(0).(*model.LicenseRecord)
}
} }
if rf, ok := ret.Get(1).(func(*model.LicenseRecord) error); ok { return r0
r1 = rf(license)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
type mockConstructorTestingTNewLicenseStore interface { type mockConstructorTestingTNewLicenseStore interface {

View File

@ -4705,10 +4705,10 @@ func (s *TimerLayerJobStore) UpdateStatusOptimistically(id string, currentStatus
return result, err return result, err
} }
func (s *TimerLayerLicenseStore) Get(id string) (*model.LicenseRecord, error) { func (s *TimerLayerLicenseStore) Get(ctx context.Context, id string) (*model.LicenseRecord, error) {
start := time.Now() start := time.Now()
result, err := s.LicenseStore.Get(id) result, err := s.LicenseStore.Get(ctx, id)
elapsed := float64(time.Since(start)) / float64(time.Second) elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil { if s.Root.Metrics != nil {
@ -4737,10 +4737,10 @@ func (s *TimerLayerLicenseStore) GetAll() ([]*model.LicenseRecord, error) {
return result, err return result, err
} }
func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, error) { func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) error {
start := time.Now() start := time.Now()
result, err := s.LicenseStore.Save(license) err := s.LicenseStore.Save(license)
elapsed := float64(time.Since(start)) / float64(time.Second) elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil { if s.Root.Metrics != nil {
@ -4750,7 +4750,7 @@ func (s *TimerLayerLicenseStore) Save(license *model.LicenseRecord) (*model.Lice
} }
s.Root.Metrics.ObserveStoreMethodDuration("LicenseStore.Save", success, elapsed) s.Root.Metrics.ObserveStoreMethodDuration("LicenseStore.Save", success, elapsed)
} }
return result, err return err
} }
func (s *TimerLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) { func (s *TimerLayerLinkMetadataStore) Get(url string, timestamp int64) (*model.LinkMetadata, error) {