diff --git a/app/license.go b/app/license.go index ce50b1e2c8..2cc1ffd313 100644 --- a/app/license.go +++ b/app/license.go @@ -36,13 +36,14 @@ func (a *App) LoadLicense() { } } - if result := <-a.Srv.Store.License().Get(licenseId); result.Err == nil { - record := result.Data.(*model.LicenseRecord) - a.ValidateAndSetLicenseBytes([]byte(record.Bytes)) - mlog.Info("License key valid unlocking enterprise features.") - } else { + record, err := a.Srv.Store.License().Get(licenseId) + if err != nil { mlog.Info("License key from https://mattermost.com required to unlock enterprise features.") + return } + + a.ValidateAndSetLicenseBytes([]byte(record.Bytes)) + mlog.Info("License key valid unlocking enterprise features.") } func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { @@ -73,11 +74,11 @@ func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) record := &model.LicenseRecord{} record.Id = license.Id record.Bytes = string(licenseBytes) - rchan := a.Srv.Store.License().Save(record) - if result := <-rchan; result.Err != nil { + _, err := a.Srv.Store.License().Save(record) + if err != nil { a.RemoveLicense() - return nil, model.NewAppError("addLicense", "api.license.add_license.save.app_error", nil, "err="+result.Err.Error(), http.StatusInternalServerError) + return nil, model.NewAppError("addLicense", "api.license.add_license.save.app_error", nil, "err="+err.Error(), http.StatusInternalServerError) } sysVar := &model.System{} diff --git a/store/sqlstore/license_store.go b/store/sqlstore/license_store.go index 0a1293deeb..a86ba78626 100644 --- a/store/sqlstore/license_store.go +++ b/store/sqlstore/license_store.go @@ -29,32 +29,30 @@ func NewSqlLicenseStore(sqlStore SqlStore) store.LicenseStore { func (ls SqlLicenseStore) CreateIndexesIfNotExists() { } -func (ls SqlLicenseStore) Save(license *model.LicenseRecord) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - license.PreSave() - if result.Err = license.IsValid(); result.Err != nil { - return - } +func (ls SqlLicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, *model.AppError) { + license.PreSave() + if err := license.IsValid(); err != nil { + return nil, err + } + var storedLicense model.LicenseRecord + if err := ls.GetReplica().SelectOne(&storedLicense, "SELECT * FROM Licenses WHERE Id = :Id", map[string]interface{}{"Id": license.Id}); err != nil { // Only insert if not exists - if err := ls.GetReplica().SelectOne(&model.LicenseRecord{}, "SELECT * FROM Licenses WHERE Id = :Id", map[string]interface{}{"Id": license.Id}); err != nil { - if err := ls.GetMaster().Insert(license); err != nil { - result.Err = model.NewAppError("SqlLicenseStore.Save", "store.sql_license.save.app_error", nil, "license_id="+license.Id+", "+err.Error(), http.StatusInternalServerError) - } else { - result.Data = license - } + if err := ls.GetMaster().Insert(license); err != nil { + return nil, model.NewAppError("SqlLicenseStore.Save", "store.sql_license.save.app_error", nil, "license_id="+license.Id+", "+err.Error(), http.StatusInternalServerError) } - }) + return license, nil + } + return &storedLicense, nil } -func (ls SqlLicenseStore) Get(id string) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - if obj, err := ls.GetReplica().Get(model.LicenseRecord{}, id); err != nil { - result.Err = model.NewAppError("SqlLicenseStore.Get", "store.sql_license.get.app_error", nil, "license_id="+id+", "+err.Error(), http.StatusInternalServerError) - } else if obj == nil { - result.Err = model.NewAppError("SqlLicenseStore.Get", "store.sql_license.get.missing.app_error", nil, "license_id="+id, http.StatusNotFound) - } else { - result.Data = obj.(*model.LicenseRecord) - } - }) +func (ls SqlLicenseStore) Get(id string) (*model.LicenseRecord, *model.AppError) { + obj, err := ls.GetReplica().Get(model.LicenseRecord{}, id) + if err != nil { + return nil, model.NewAppError("SqlLicenseStore.Get", "store.sql_license.get.app_error", nil, "license_id="+id+", "+err.Error(), http.StatusInternalServerError) + } + if obj == nil { + return nil, model.NewAppError("SqlLicenseStore.Get", "store.sql_license.get.missing.app_error", nil, "license_id="+id, http.StatusNotFound) + } + return obj.(*model.LicenseRecord), nil } diff --git a/store/store.go b/store/store.go index b02cf5c0dc..16bf8fffc3 100644 --- a/store/store.go +++ b/store/store.go @@ -436,8 +436,8 @@ type PreferenceStore interface { } type LicenseStore interface { - Save(license *model.LicenseRecord) StoreChannel - Get(id string) StoreChannel + Save(license *model.LicenseRecord) (*model.LicenseRecord, *model.AppError) + Get(id string) (*model.LicenseRecord, *model.AppError) } type TokenStore interface { diff --git a/store/storetest/license_store.go b/store/storetest/license_store.go index 452d37e7b9..ab9722dd9b 100644 --- a/store/storetest/license_store.go +++ b/store/storetest/license_store.go @@ -8,6 +8,7 @@ import ( "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/store" + "github.com/stretchr/testify/require" ) func TestLicenseStore(t *testing.T, ss store.Store) { @@ -20,17 +21,17 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) { l1.Id = model.NewId() l1.Bytes = "junk" - if err := (<-ss.License().Save(&l1)).Err; err != nil { + if _, err := ss.License().Save(&l1); err != nil { t.Fatal("couldn't save license record", err) } - if err := (<-ss.License().Save(&l1)).Err; err != nil { + if _, err := ss.License().Save(&l1); err != nil { t.Fatal("shouldn't fail on trying to save existing license record", err) } l1.Id = "" - if err := (<-ss.License().Save(&l1)).Err; err == nil { + if _, err := ss.License().Save(&l1); err == nil { t.Fatal("should fail on invalid license", err) } } @@ -40,17 +41,18 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) { l1.Id = model.NewId() l1.Bytes = "junk" - store.Must(ss.License().Save(&l1)) + _, err := ss.License().Save(&l1) + require.Nil(t, err) - if r := <-ss.License().Get(l1.Id); r.Err != nil { - t.Fatal("couldn't get license", r.Err) + if record, err := ss.License().Get(l1.Id); err != nil { + t.Fatal("couldn't get license", err) } else { - if r.Data.(*model.LicenseRecord).Bytes != l1.Bytes { + if record.Bytes != l1.Bytes { t.Fatal("license bytes didn't match") } } - if err := (<-ss.License().Get("missing")).Err; err == nil { + if _, err := ss.License().Get("missing"); err == nil { t.Fatal("should fail on get license", err) } } diff --git a/store/storetest/mocks/LicenseStore.go b/store/storetest/mocks/LicenseStore.go index f00ebba78a..b769b39404 100644 --- a/store/storetest/mocks/LicenseStore.go +++ b/store/storetest/mocks/LicenseStore.go @@ -6,7 +6,6 @@ package mocks import mock "github.com/stretchr/testify/mock" import model "github.com/mattermost/mattermost-server/model" -import store "github.com/mattermost/mattermost-server/store" // LicenseStore is an autogenerated mock type for the LicenseStore type type LicenseStore struct { @@ -14,33 +13,51 @@ type LicenseStore struct { } // Get provides a mock function with given fields: id -func (_m *LicenseStore) Get(id string) store.StoreChannel { +func (_m *LicenseStore) Get(id string) (*model.LicenseRecord, *model.AppError) { ret := _m.Called(id) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(string) store.StoreChannel); ok { + var r0 *model.LicenseRecord + if rf, ok := ret.Get(0).(func(string) *model.LicenseRecord); ok { r0 = rf(id) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.LicenseRecord) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string) *model.AppError); ok { + r1 = rf(id) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // Save provides a mock function with given fields: license -func (_m *LicenseStore) Save(license *model.LicenseRecord) store.StoreChannel { +func (_m *LicenseStore) Save(license *model.LicenseRecord) (*model.LicenseRecord, *model.AppError) { ret := _m.Called(license) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(*model.LicenseRecord) store.StoreChannel); ok { + var r0 *model.LicenseRecord + if rf, ok := ret.Get(0).(func(*model.LicenseRecord) *model.LicenseRecord); ok { r0 = rf(license) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.LicenseRecord) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.LicenseRecord) *model.AppError); ok { + r1 = rf(license) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 }