mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
SSO: fix settings merge for SAML fields (#86970)
* fix sso settings merge for saml fields * change func name
This commit is contained in:
parent
125ac18fa3
commit
76d94b35c9
@ -470,7 +470,9 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
|
||||
|
||||
for k, v := range systemSettings {
|
||||
if _, ok := settings[k]; !ok {
|
||||
settings[k] = v
|
||||
if isMergingAllowed(k) {
|
||||
settings[k] = v
|
||||
}
|
||||
} else if isURL(k) && isEmptyString(settings[k]) {
|
||||
// Overwrite all URL settings from the DB containing an empty string with their value
|
||||
// from the system settings. This fixes an issue with empty auth_url, api_url and token_url
|
||||
@ -483,6 +485,20 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
|
||||
return settings
|
||||
}
|
||||
|
||||
// isMergingAllowed returns true if the field provided can be merged from the system settings.
|
||||
// It won't allow SAML fields that are part of a group of settings to be merged from system settings
|
||||
// because the DB settings already contain one valid setting from each group.
|
||||
func isMergingAllowed(fieldName string) bool {
|
||||
forbiddenMergePatterns := []string{"certificate", "private_key", "idp_metadata"}
|
||||
|
||||
for _, v := range forbiddenMergePatterns {
|
||||
if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// mergeSecrets returns a new map with the current value for secrets that have not been updated
|
||||
func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) {
|
||||
settingsWithSecrets := map[string]any{}
|
||||
|
@ -40,13 +40,15 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(env testEnv)
|
||||
want *models.SSOSettings
|
||||
wantErr bool
|
||||
name string
|
||||
provider string
|
||||
setup func(env testEnv)
|
||||
want *models.SSOSettings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "should return successfully",
|
||||
name: "should return successfully",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
@ -72,13 +74,15 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error if store returns an error different than not found",
|
||||
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
name: "should return error if store returns an error different than not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "should fallback to the system settings if store returns not found",
|
||||
name: "should fallback to the system settings if store returns not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -99,7 +103,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error if the fallback strategy was not found",
|
||||
name: "should return error if the fallback strategy was not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = false
|
||||
@ -108,7 +113,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "should return error if fallback strategy returns error",
|
||||
name: "should return error if fallback strategy returns error",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -118,7 +124,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "should decrypt secrets if data is coming from store",
|
||||
name: "should decrypt secrets if data is coming from store",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
@ -152,7 +159,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should not decrypt secrets if data is coming from the fallback strategy",
|
||||
name: "should not decrypt secrets if data is coming from the fallback strategy",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -176,7 +184,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return an error if the data in the store is invalid",
|
||||
name: "should return an error if the data in the store is invalid",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
@ -196,7 +205,8 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "correctly merge the DB and system settings",
|
||||
name: "correctly merge URLs from the DB and system settings",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
@ -231,6 +241,45 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "correctly merge group of settings for SAML",
|
||||
provider: "saml",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "saml",
|
||||
Settings: map[string]any{
|
||||
"certificate": base64.RawStdEncoding.EncodeToString([]byte("valid-certificate")),
|
||||
"private_key_path": base64.RawStdEncoding.EncodeToString([]byte("path/to/private/key")),
|
||||
"idp_metadata_url": "https://idp-metadata.com",
|
||||
},
|
||||
Source: models.DB,
|
||||
}
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{
|
||||
"saml": {
|
||||
"name": "test-settings",
|
||||
"certificate_path": "path/to/certificate",
|
||||
"private_key": "this-is-a-valid-private-key",
|
||||
"idp_metadata_path": "path/to/metadata",
|
||||
"max_issue_delay": "1h",
|
||||
},
|
||||
}
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("valid-certificate"), mock.Anything).Return([]byte("decrypted-valid-certificate"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("path/to/private/key"), mock.Anything).Return([]byte("decrypted/path/to/private/key"), nil).Once()
|
||||
},
|
||||
want: &models.SSOSettings{
|
||||
Provider: "saml",
|
||||
Settings: map[string]any{
|
||||
"name": "test-settings",
|
||||
"certificate": "decrypted-valid-certificate",
|
||||
"private_key_path": "decrypted/path/to/private/key",
|
||||
"idp_metadata_url": "https://idp-metadata.com",
|
||||
"max_issue_delay": "1h",
|
||||
},
|
||||
Source: models.DB,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -241,12 +290,12 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, true, false, true)
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
}
|
||||
|
||||
actual, err := env.service.GetForProvider(context.Background(), "github")
|
||||
actual, err := env.service.GetForProvider(context.Background(), tc.provider)
|
||||
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
|
Loading…
Reference in New Issue
Block a user