[MM-56653] Improve license loading errors (#26050)

This commit is contained in:
Ben Schumacher 2024-04-05 16:59:19 +02:00 committed by GitHub
parent 1a9355b2eb
commit 71e26b8df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 86 additions and 84 deletions

View File

@ -109,7 +109,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
@ -144,7 +144,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, err := json.Marshal(license)
require.NoError(t, err)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseBytes))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseBytes), nil)
utils.LicenseValidator = &mockLicenseValidator
resp, err := th.SystemAdminClient.UploadLicenseFile(context.Background(), []byte(""))
@ -177,7 +177,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)
utils.LicenseValidator = &mockLicenseValidator
@ -275,7 +275,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -306,7 +306,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -344,7 +344,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -405,7 +405,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
@ -435,7 +435,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()

View File

@ -129,7 +129,7 @@ func (s *Server) SetLicense(license *model.License) bool {
return s.platform.SetLicense(license)
}
func (s *Server) ValidateAndSetLicenseBytes(b []byte) bool {
func (s *Server) ValidateAndSetLicenseBytes(b []byte) error {
return s.platform.ValidateAndSetLicenseBytes(b)
}

View File

@ -54,9 +54,9 @@ func (ps *PlatformService) LoadLicense() {
// ENV var overrides all other sources of license.
licenseStr := os.Getenv(LicenseEnv)
if licenseStr != "" {
license, err := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr))
if err != nil {
ps.logger.Error("Failed to read license set in environment.", mlog.Err(err))
license, appErr := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr))
if appErr != nil {
ps.logger.Error("Failed to read license set in environment.", mlog.Err(appErr))
return
}
@ -74,7 +74,9 @@ func (ps *PlatformService) LoadLicense() {
}
}
if ps.ValidateAndSetLicenseBytes([]byte(licenseStr)) {
if err := ps.ValidateAndSetLicenseBytes([]byte(licenseStr)); err != nil {
ps.logger.Info("License key from ENV is invalid.", mlog.Err(err))
} else {
ps.logger.Info("License key from ENV is valid, unlocking enterprise features.")
}
return
@ -88,9 +90,10 @@ func (ps *PlatformService) LoadLicense() {
if !model.IsValidId(licenseId) {
// Lets attempt to load the file from disk since it was missing from the DB
license, licenseBytes := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation)
if license != nil {
license, licenseBytes, err := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation)
if err != nil {
ps.logger.Warn("Failed to get license from disk", mlog.Err(err))
} else {
if _, err := ps.SaveLicense(licenseBytes); err != nil {
ps.logger.Error("Failed to save license key loaded from disk.", mlog.Err(err))
} else {
@ -101,19 +104,23 @@ func (ps *PlatformService) LoadLicense() {
record, nErr := ps.Store.License().Get(sqlstore.RequestContextWithMaster(c), licenseId)
if nErr != nil {
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.logger.Warn("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.SetLicense(nil)
return
}
ps.ValidateAndSetLicenseBytes([]byte(record.Bytes))
ps.logger.Info("License key valid unlocking enterprise features.")
err := ps.ValidateAndSetLicenseBytes([]byte(record.Bytes))
if err != nil {
ps.logger.Info("License key is invalid.")
}
ps.logger.Info("License key is valid, unlocking enterprise features.")
}
func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := utils.LicenseValidator.ValidateLicense(licenseBytes)
if !success {
return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest)
licenseStr, err := utils.LicenseValidator.ValidateLicense(licenseBytes)
if err != nil {
return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
}
var license model.License
@ -231,19 +238,19 @@ func (ps *PlatformService) SetLicense(license *model.License) bool {
return false
}
func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) bool {
if success, licenseStr := utils.LicenseValidator.ValidateLicense(b); success {
var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
ps.logger.Warn("Failed to decode license from JSON", mlog.Err(jsonErr))
return false
}
ps.SetLicense(&license)
return true
func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) error {
licenseStr, err := utils.LicenseValidator.ValidateLicense(b)
if err != nil {
return errors.Wrap(err, "Failed to decode license from JSON")
}
ps.logger.Warn("No valid enterprise license found")
return false
var license model.License
if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return errors.Wrap(err, "Failed to decode license from JSON")
}
ps.SetLicense(&license)
return nil
}
func (ps *PlatformService) SetClientLicense(m map[string]string) {

View File

@ -11,6 +11,7 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"os"
@ -32,33 +33,32 @@ func init() {
type LicenseValidatorIface interface {
LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError)
ValidateLicense(signed []byte) (bool, string)
ValidateLicense(signed []byte) (string, error)
}
type LicenseValidatorImpl struct {
}
func (l *LicenseValidatorImpl) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := l.ValidateLicense(licenseBytes)
if !success {
return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest)
licenseStr, err := l.ValidateLicense(licenseBytes)
if err != nil {
return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
}
var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(err)
}
return &license, nil
}
func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (string, error) {
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(signed)))
_, err := base64.StdEncoding.Decode(decoded, signed)
if err != nil {
mlog.Error("Encountered error decoding license", mlog.Err(err))
return false, ""
return "", fmt.Errorf("encountered error decoding license: %w", err)
}
// remove null terminator
@ -67,8 +67,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
}
if len(decoded) <= 256 {
mlog.Error("Signed license not long enough")
return false, ""
return "", fmt.Errorf("Signed license not long enough")
}
plaintext := decoded[:len(decoded)-256]
@ -85,8 +84,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
public, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
mlog.Error("Encountered error signing license", mlog.Err(err))
return false, ""
return "", fmt.Errorf("Encountered error signing license: %w", err)
}
rsaPublic := public.(*rsa.PublicKey)
@ -97,37 +95,34 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
err = rsa.VerifyPKCS1v15(rsaPublic, crypto.SHA512, d, signature)
if err != nil {
mlog.Error("Invalid signature", mlog.Err(err))
return false, ""
return "", fmt.Errorf("Invalid signature: %w", err)
}
return true, string(plaintext)
return string(plaintext), nil
}
func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte) {
func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte, error) {
fileName := GetLicenseFileLocation(location)
mlog.Info("License key has not been uploaded. Loading license key from disk.", mlog.String("filename", fileName))
if _, err := os.Stat(fileName); err != nil {
mlog.Debug("We could not find the license key in the database or on disk at", mlog.String("filename", fileName))
return nil, nil
return nil, nil, fmt.Errorf("We could not find the license key on disk at %s: %w", fileName, err)
}
mlog.Info("License key has not been uploaded. Loading license key from disk at", mlog.String("filename", fileName))
licenseBytes := GetLicenseFileFromDisk(fileName)
success, licenseStr := LicenseValidator.ValidateLicense(licenseBytes)
if !success {
mlog.Error("Found license key at %v but it appears to be invalid.", mlog.String("filename", fileName))
return nil, nil
licenseStr, err := LicenseValidator.ValidateLicense(licenseBytes)
if err != nil {
return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
}
var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
mlog.Error("Failed to decode license from JSON", mlog.Err(jsonErr))
return nil, nil
return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
}
return &license, licenseBytes
return &license, licenseBytes, nil
}
func GetLicenseFileFromDisk(fileName string) []byte {

View File

@ -19,12 +19,12 @@ var validTestLicense = []byte("eyJpZCI6InpvZ3c2NW44Z2lmajVkbHJoYThtYnUxcGl3Iiwia
func TestValidateLicense(t *testing.T) {
t.Run("should fail with junk data", func(t *testing.T) {
b1 := []byte("junk")
ok, _ := LicenseValidator.ValidateLicense(b1)
require.False(t, ok, "should have failed - bad license")
_, err := LicenseValidator.ValidateLicense(b1)
require.Error(t, err, "should have failed - bad license")
b2 := []byte("junkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunk")
ok, _ = LicenseValidator.ValidateLicense(b2)
require.False(t, ok, "should have failed - bad license")
_, err = LicenseValidator.ValidateLicense(b2)
require.Error(t, err, "should have failed - bad license")
})
t.Run("should not panic on shorter than expected input", func(t *testing.T) {
@ -42,8 +42,8 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close()
require.NoError(t, err)
ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.Error(t, err)
require.Empty(t, str)
})
@ -61,8 +61,8 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close()
require.NoError(t, err)
ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.Error(t, err)
require.Empty(t, str)
})
@ -70,8 +70,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(nil)
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(nil)
require.Error(t, err)
require.Empty(t, str)
})
@ -79,8 +79,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(validTestLicense)
require.True(t, ok)
str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.NoError(t, err)
require.NotEmpty(t, str)
})
@ -88,8 +88,8 @@ func TestValidateLicense(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentProduction)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")
ok, str := LicenseValidator.ValidateLicense(validTestLicense)
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.Error(t, err)
require.Empty(t, str)
})
}
@ -117,7 +117,7 @@ func TestGetLicenseFileFromDisk(t *testing.T) {
fileBytes := GetLicenseFileFromDisk(f.Name())
require.NotEmpty(t, fileBytes, "should have read the file")
success, _ := LicenseValidator.ValidateLicense(fileBytes)
assert.False(t, success, "should have been an invalid file")
_, err = LicenseValidator.ValidateLicense(fileBytes)
assert.Error(t, err, "should have been an invalid file")
})
}

View File

@ -43,24 +43,24 @@ func (_m *LicenseValidatorIface) LicenseFromBytes(licenseBytes []byte) (*model.L
}
// ValidateLicense provides a mock function with given fields: signed
func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (bool, string) {
func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (string, error) {
ret := _m.Called(signed)
var r0 bool
var r1 string
if rf, ok := ret.Get(0).(func([]byte) (bool, string)); ok {
var r0 string
var r1 error
if rf, ok := ret.Get(0).(func([]byte) (string, error)); ok {
return rf(signed)
}
if rf, ok := ret.Get(0).(func([]byte) bool); ok {
if rf, ok := ret.Get(0).(func([]byte) string); ok {
r0 = rf(signed)
} else {
r0 = ret.Get(0).(bool)
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func([]byte) string); ok {
if rf, ok := ret.Get(1).(func([]byte) error); ok {
r1 = rf(signed)
} else {
r1 = ret.Get(1).(string)
r1 = ret.Error(1)
}
return r0, r1