From 76d94b35c98dbe954033fa716a530997b1a52af8 Mon Sep 17 00:00:00 2001 From: Mihai Doarna Date: Tue, 30 Apr 2024 15:10:27 +0300 Subject: [PATCH] SSO: fix settings merge for SAML fields (#86970) * fix sso settings merge for saml fields * change func name --- .../ssosettings/ssosettingsimpl/service.go | 18 +++- .../ssosettingsimpl/service_test.go | 85 +++++++++++++++---- 2 files changed, 84 insertions(+), 19 deletions(-) diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go index f1f2f3e97bc..1bb029821ad 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -470,7 +470,9 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any for k, v := range systemSettings { if _, ok := settings[k]; !ok { - settings[k] = v + if isMergingAllowed(k) { + settings[k] = v + } } else if isURL(k) && isEmptyString(settings[k]) { // Overwrite all URL settings from the DB containing an empty string with their value // from the system settings. This fixes an issue with empty auth_url, api_url and token_url @@ -483,6 +485,20 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any return settings } +// isMergingAllowed returns true if the field provided can be merged from the system settings. +// It won't allow SAML fields that are part of a group of settings to be merged from system settings +// because the DB settings already contain one valid setting from each group. +func isMergingAllowed(fieldName string) bool { + forbiddenMergePatterns := []string{"certificate", "private_key", "idp_metadata"} + + for _, v := range forbiddenMergePatterns { + if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) { + return false + } + } + return true +} + // mergeSecrets returns a new map with the current value for secrets that have not been updated func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) { settingsWithSecrets := map[string]any{} diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go index 7a9bd0919ad..7f9cd0aba05 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service_test.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -40,13 +40,15 @@ func TestService_GetForProvider(t *testing.T) { t.Parallel() testCases := []struct { - name string - setup func(env testEnv) - want *models.SSOSettings - wantErr bool + name string + provider string + setup func(env testEnv) + want *models.SSOSettings + wantErr bool }{ { - name: "should return successfully", + name: "should return successfully", + provider: "github", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", @@ -72,13 +74,15 @@ func TestService_GetForProvider(t *testing.T) { wantErr: false, }, { - name: "should return error if store returns an error different than not found", - setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, - want: nil, - wantErr: true, + name: "should return error if store returns an error different than not found", + provider: "github", + setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, + want: nil, + wantErr: true, }, { - name: "should fallback to the system settings if store returns not found", + name: "should fallback to the system settings if store returns not found", + provider: "github", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true @@ -99,7 +103,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: false, }, { - name: "should return error if the fallback strategy was not found", + name: "should return error if the fallback strategy was not found", + provider: "github", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = false @@ -108,7 +113,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: true, }, { - name: "should return error if fallback strategy returns error", + name: "should return error if fallback strategy returns error", + provider: "github", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true @@ -118,7 +124,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: true, }, { - name: "should decrypt secrets if data is coming from store", + name: "should decrypt secrets if data is coming from store", + provider: "github", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", @@ -152,7 +159,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: false, }, { - name: "should not decrypt secrets if data is coming from the fallback strategy", + name: "should not decrypt secrets if data is coming from the fallback strategy", + provider: "github", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true @@ -176,7 +184,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: false, }, { - name: "should return an error if the data in the store is invalid", + name: "should return an error if the data in the store is invalid", + provider: "github", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", @@ -196,7 +205,8 @@ func TestService_GetForProvider(t *testing.T) { wantErr: true, }, { - name: "correctly merge the DB and system settings", + name: "correctly merge URLs from the DB and system settings", + provider: "github", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", @@ -231,6 +241,45 @@ func TestService_GetForProvider(t *testing.T) { }, wantErr: false, }, + { + name: "correctly merge group of settings for SAML", + provider: "saml", + setup: func(env testEnv) { + env.store.ExpectedSSOSetting = &models.SSOSettings{ + Provider: "saml", + Settings: map[string]any{ + "certificate": base64.RawStdEncoding.EncodeToString([]byte("valid-certificate")), + "private_key_path": base64.RawStdEncoding.EncodeToString([]byte("path/to/private/key")), + "idp_metadata_url": "https://idp-metadata.com", + }, + Source: models.DB, + } + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ + "saml": { + "name": "test-settings", + "certificate_path": "path/to/certificate", + "private_key": "this-is-a-valid-private-key", + "idp_metadata_path": "path/to/metadata", + "max_issue_delay": "1h", + }, + } + env.secrets.On("Decrypt", mock.Anything, []byte("valid-certificate"), mock.Anything).Return([]byte("decrypted-valid-certificate"), nil).Once() + env.secrets.On("Decrypt", mock.Anything, []byte("path/to/private/key"), mock.Anything).Return([]byte("decrypted/path/to/private/key"), nil).Once() + }, + want: &models.SSOSettings{ + Provider: "saml", + Settings: map[string]any{ + "name": "test-settings", + "certificate": "decrypted-valid-certificate", + "private_key_path": "decrypted/path/to/private/key", + "idp_metadata_url": "https://idp-metadata.com", + "max_issue_delay": "1h", + }, + Source: models.DB, + }, + wantErr: false, + }, } for _, tc := range testCases { @@ -241,12 +290,12 @@ func TestService_GetForProvider(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - env := setupTestEnv(t, false, false, false) + env := setupTestEnv(t, true, false, true) if tc.setup != nil { tc.setup(env) } - actual, err := env.service.GetForProvider(context.Background(), "github") + actual, err := env.service.GetForProvider(context.Background(), tc.provider) if tc.wantErr { require.Error(t, err)