[MM-56653] Improve license loading errors (#26050)

This commit is contained in:
Ben Schumacher 2024-04-05 16:59:19 +02:00 committed by GitHub
parent 1a9355b2eb
commit 71e26b8df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 86 additions and 84 deletions

View File

@ -109,7 +109,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license) licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes) licenseStr := string(licenseBytes)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
@ -144,7 +144,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, err := json.Marshal(license) licenseBytes, err := json.Marshal(license)
require.NoError(t, err) require.NoError(t, err)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseBytes)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseBytes), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
resp, err := th.SystemAdminClient.UploadLicenseFile(context.Background(), []byte("")) resp, err := th.SystemAdminClient.UploadLicenseFile(context.Background(), []byte(""))
@ -177,7 +177,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license) licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes) licenseStr := string(licenseBytes)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
@ -275,7 +275,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{} mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator() defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -306,7 +306,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{} mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator() defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -344,7 +344,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{} mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator() defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -405,7 +405,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{} mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator() defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -435,7 +435,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{} mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator() defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()

View File

@ -129,7 +129,7 @@ func (s *Server) SetLicense(license *model.License) bool {
return s.platform.SetLicense(license) return s.platform.SetLicense(license)
} }
func (s *Server) ValidateAndSetLicenseBytes(b []byte) bool { func (s *Server) ValidateAndSetLicenseBytes(b []byte) error {
return s.platform.ValidateAndSetLicenseBytes(b) return s.platform.ValidateAndSetLicenseBytes(b)
} }

View File

@ -54,9 +54,9 @@ func (ps *PlatformService) LoadLicense() {
// ENV var overrides all other sources of license. // ENV var overrides all other sources of license.
licenseStr := os.Getenv(LicenseEnv) licenseStr := os.Getenv(LicenseEnv)
if licenseStr != "" { if licenseStr != "" {
license, err := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr)) license, appErr := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr))
if err != nil { if appErr != nil {
ps.logger.Error("Failed to read license set in environment.", mlog.Err(err)) ps.logger.Error("Failed to read license set in environment.", mlog.Err(appErr))
return return
} }
@ -74,7 +74,9 @@ func (ps *PlatformService) LoadLicense() {
} }
} }
if ps.ValidateAndSetLicenseBytes([]byte(licenseStr)) { if err := ps.ValidateAndSetLicenseBytes([]byte(licenseStr)); err != nil {
ps.logger.Info("License key from ENV is invalid.", mlog.Err(err))
} else {
ps.logger.Info("License key from ENV is valid, unlocking enterprise features.") ps.logger.Info("License key from ENV is valid, unlocking enterprise features.")
} }
return return
@ -88,9 +90,10 @@ func (ps *PlatformService) LoadLicense() {
if !model.IsValidId(licenseId) { if !model.IsValidId(licenseId) {
// Lets attempt to load the file from disk since it was missing from the DB // Lets attempt to load the file from disk since it was missing from the DB
license, licenseBytes := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation) license, licenseBytes, err := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation)
if err != nil {
if license != nil { ps.logger.Warn("Failed to get license from disk", mlog.Err(err))
} else {
if _, err := ps.SaveLicense(licenseBytes); err != nil { if _, err := ps.SaveLicense(licenseBytes); err != nil {
ps.logger.Error("Failed to save license key loaded from disk.", mlog.Err(err)) ps.logger.Error("Failed to save license key loaded from disk.", mlog.Err(err))
} else { } else {
@ -101,19 +104,23 @@ func (ps *PlatformService) LoadLicense() {
record, nErr := ps.Store.License().Get(sqlstore.RequestContextWithMaster(c), licenseId) record, nErr := ps.Store.License().Get(sqlstore.RequestContextWithMaster(c), licenseId)
if nErr != nil { if nErr != nil {
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr)) ps.logger.Warn("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.SetLicense(nil) ps.SetLicense(nil)
return return
} }
ps.ValidateAndSetLicenseBytes([]byte(record.Bytes)) err := ps.ValidateAndSetLicenseBytes([]byte(record.Bytes))
ps.logger.Info("License key valid unlocking enterprise features.") if err != nil {
ps.logger.Info("License key is invalid.")
}
ps.logger.Info("License key is valid, unlocking enterprise features.")
} }
func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := utils.LicenseValidator.ValidateLicense(licenseBytes) licenseStr, err := utils.LicenseValidator.ValidateLicense(licenseBytes)
if !success { if err != nil {
return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest) return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
} }
var license model.License var license model.License
@ -231,19 +238,19 @@ func (ps *PlatformService) SetLicense(license *model.License) bool {
return false return false
} }
func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) bool { func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) error {
if success, licenseStr := utils.LicenseValidator.ValidateLicense(b); success { licenseStr, err := utils.LicenseValidator.ValidateLicense(b)
var license model.License if err != nil {
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { return errors.Wrap(err, "Failed to decode license from JSON")
ps.logger.Warn("Failed to decode license from JSON", mlog.Err(jsonErr))
return false
}
ps.SetLicense(&license)
return true
} }
ps.logger.Warn("No valid enterprise license found") var license model.License
return false if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return errors.Wrap(err, "Failed to decode license from JSON")
}
ps.SetLicense(&license)
return nil
} }
func (ps *PlatformService) SetClientLicense(m map[string]string) { func (ps *PlatformService) SetClientLicense(m map[string]string) {

View File

@ -11,6 +11,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -32,33 +33,32 @@ func init() {
type LicenseValidatorIface interface { type LicenseValidatorIface interface {
LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError)
ValidateLicense(signed []byte) (bool, string) ValidateLicense(signed []byte) (string, error)
} }
type LicenseValidatorImpl struct { type LicenseValidatorImpl struct {
} }
func (l *LicenseValidatorImpl) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) { func (l *LicenseValidatorImpl) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := l.ValidateLicense(licenseBytes) licenseStr, err := l.ValidateLicense(licenseBytes)
if !success { if err != nil {
return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest) return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
} }
var license model.License var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr) return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(err)
} }
return &license, nil return &license, nil
} }
func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) { func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (string, error) {
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(signed))) decoded := make([]byte, base64.StdEncoding.DecodedLen(len(signed)))
_, err := base64.StdEncoding.Decode(decoded, signed) _, err := base64.StdEncoding.Decode(decoded, signed)
if err != nil { if err != nil {
mlog.Error("Encountered error decoding license", mlog.Err(err)) return "", fmt.Errorf("encountered error decoding license: %w", err)
return false, ""
} }
// remove null terminator // remove null terminator
@ -67,8 +67,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
} }
if len(decoded) <= 256 { if len(decoded) <= 256 {
mlog.Error("Signed license not long enough") return "", fmt.Errorf("Signed license not long enough")
return false, ""
} }
plaintext := decoded[:len(decoded)-256] plaintext := decoded[:len(decoded)-256]
@ -85,8 +84,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
public, err := x509.ParsePKIXPublicKey(block.Bytes) public, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
mlog.Error("Encountered error signing license", mlog.Err(err)) return "", fmt.Errorf("Encountered error signing license: %w", err)
return false, ""
} }
rsaPublic := public.(*rsa.PublicKey) rsaPublic := public.(*rsa.PublicKey)
@ -97,37 +95,34 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
err = rsa.VerifyPKCS1v15(rsaPublic, crypto.SHA512, d, signature) err = rsa.VerifyPKCS1v15(rsaPublic, crypto.SHA512, d, signature)
if err != nil { if err != nil {
mlog.Error("Invalid signature", mlog.Err(err)) return "", fmt.Errorf("Invalid signature: %w", err)
return false, ""
} }
return true, string(plaintext) return string(plaintext), nil
} }
func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte) { func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte, error) {
fileName := GetLicenseFileLocation(location) fileName := GetLicenseFileLocation(location)
mlog.Info("License key has not been uploaded. Loading license key from disk.", mlog.String("filename", fileName))
if _, err := os.Stat(fileName); err != nil { if _, err := os.Stat(fileName); err != nil {
mlog.Debug("We could not find the license key in the database or on disk at", mlog.String("filename", fileName)) return nil, nil, fmt.Errorf("We could not find the license key on disk at %s: %w", fileName, err)
return nil, nil
} }
mlog.Info("License key has not been uploaded. Loading license key from disk at", mlog.String("filename", fileName))
licenseBytes := GetLicenseFileFromDisk(fileName) licenseBytes := GetLicenseFileFromDisk(fileName)
success, licenseStr := LicenseValidator.ValidateLicense(licenseBytes) licenseStr, err := LicenseValidator.ValidateLicense(licenseBytes)
if !success { if err != nil {
mlog.Error("Found license key at %v but it appears to be invalid.", mlog.String("filename", fileName)) return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
return nil, nil
} }
var license model.License var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
mlog.Error("Failed to decode license from JSON", mlog.Err(jsonErr)) return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
return nil, nil
} }
return &license, licenseBytes return &license, licenseBytes, nil
} }
func GetLicenseFileFromDisk(fileName string) []byte { func GetLicenseFileFromDisk(fileName string) []byte {

View File

@ -19,12 +19,12 @@ var validTestLicense = []byte("eyJpZCI6InpvZ3c2NW44Z2lmajVkbHJoYThtYnUxcGl3Iiwia
func TestValidateLicense(t *testing.T) { func TestValidateLicense(t *testing.T) {
t.Run("should fail with junk data", func(t *testing.T) { t.Run("should fail with junk data", func(t *testing.T) {
b1 := []byte("junk") b1 := []byte("junk")
ok, _ := LicenseValidator.ValidateLicense(b1) _, err := LicenseValidator.ValidateLicense(b1)
require.False(t, ok, "should have failed - bad license") require.Error(t, err, "should have failed - bad license")
b2 := []byte("junkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunk") b2 := []byte("junkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunk")
ok, _ = LicenseValidator.ValidateLicense(b2) _, err = LicenseValidator.ValidateLicense(b2)
require.False(t, ok, "should have failed - bad license") require.Error(t, err, "should have failed - bad license")
}) })
t.Run("should not panic on shorter than expected input", func(t *testing.T) { t.Run("should not panic on shorter than expected input", func(t *testing.T) {
@ -42,8 +42,8 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close() err = encoder.Close()
require.NoError(t, err) require.NoError(t, err)
ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes()) str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok) require.Error(t, err)
require.Empty(t, str) require.Empty(t, str)
}) })
@ -61,8 +61,8 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close() err = encoder.Close()
require.NoError(t, err) require.NoError(t, err)
ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes()) str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok) require.Error(t, err)
require.Empty(t, str) require.Empty(t, str)
}) })
@ -70,8 +70,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest) os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT") defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(nil) str, err := LicenseValidator.ValidateLicense(nil)
require.False(t, ok) require.Error(t, err)
require.Empty(t, str) require.Empty(t, str)
}) })
@ -79,8 +79,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest) os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT") defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(validTestLicense) str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.True(t, ok) require.NoError(t, err)
require.NotEmpty(t, str) require.NotEmpty(t, str)
}) })
@ -88,8 +88,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentProduction) os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentProduction)
defer os.Unsetenv("MM_SERVICEENVIRONMENT") defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(validTestLicense) str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.False(t, ok) require.Error(t, err)
require.Empty(t, str) require.Empty(t, str)
}) })
} }
@ -117,7 +117,7 @@ func TestGetLicenseFileFromDisk(t *testing.T) {
fileBytes := GetLicenseFileFromDisk(f.Name()) fileBytes := GetLicenseFileFromDisk(f.Name())
require.NotEmpty(t, fileBytes, "should have read the file") require.NotEmpty(t, fileBytes, "should have read the file")
success, _ := LicenseValidator.ValidateLicense(fileBytes) _, err = LicenseValidator.ValidateLicense(fileBytes)
assert.False(t, success, "should have been an invalid file") assert.Error(t, err, "should have been an invalid file")
}) })
} }

View File

@ -43,24 +43,24 @@ func (_m *LicenseValidatorIface) LicenseFromBytes(licenseBytes []byte) (*model.L
} }
// ValidateLicense provides a mock function with given fields: signed // ValidateLicense provides a mock function with given fields: signed
func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (bool, string) { func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (string, error) {
ret := _m.Called(signed) ret := _m.Called(signed)
var r0 bool var r0 string
var r1 string var r1 error
if rf, ok := ret.Get(0).(func([]byte) (bool, string)); ok { if rf, ok := ret.Get(0).(func([]byte) (string, error)); ok {
return rf(signed) return rf(signed)
} }
if rf, ok := ret.Get(0).(func([]byte) bool); ok { if rf, ok := ret.Get(0).(func([]byte) string); ok {
r0 = rf(signed) r0 = rf(signed)
} else { } else {
r0 = ret.Get(0).(bool) r0 = ret.Get(0).(string)
} }
if rf, ok := ret.Get(1).(func([]byte) string); ok { if rf, ok := ret.Get(1).(func([]byte) error); ok {
r1 = rf(signed) r1 = rf(signed)
} else { } else {
r1 = ret.Get(1).(string) r1 = ret.Error(1)
} }
return r0, r1 return r0, r1