mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
SSO: fix settings merge for SAML fields (#86970)
* fix sso settings merge for saml fields * change func name
This commit is contained in:
parent
125ac18fa3
commit
76d94b35c9
@ -470,7 +470,9 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
|
|||||||
|
|
||||||
for k, v := range systemSettings {
|
for k, v := range systemSettings {
|
||||||
if _, ok := settings[k]; !ok {
|
if _, ok := settings[k]; !ok {
|
||||||
settings[k] = v
|
if isMergingAllowed(k) {
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
} else if isURL(k) && isEmptyString(settings[k]) {
|
} else if isURL(k) && isEmptyString(settings[k]) {
|
||||||
// Overwrite all URL settings from the DB containing an empty string with their value
|
// Overwrite all URL settings from the DB containing an empty string with their value
|
||||||
// from the system settings. This fixes an issue with empty auth_url, api_url and token_url
|
// from the system settings. This fixes an issue with empty auth_url, api_url and token_url
|
||||||
@ -483,6 +485,20 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
|
|||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMergingAllowed returns true if the field provided can be merged from the system settings.
|
||||||
|
// It won't allow SAML fields that are part of a group of settings to be merged from system settings
|
||||||
|
// because the DB settings already contain one valid setting from each group.
|
||||||
|
func isMergingAllowed(fieldName string) bool {
|
||||||
|
forbiddenMergePatterns := []string{"certificate", "private_key", "idp_metadata"}
|
||||||
|
|
||||||
|
for _, v := range forbiddenMergePatterns {
|
||||||
|
if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// mergeSecrets returns a new map with the current value for secrets that have not been updated
|
// mergeSecrets returns a new map with the current value for secrets that have not been updated
|
||||||
func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) {
|
func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) {
|
||||||
settingsWithSecrets := map[string]any{}
|
settingsWithSecrets := map[string]any{}
|
||||||
|
@ -40,13 +40,15 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
setup func(env testEnv)
|
provider string
|
||||||
want *models.SSOSettings
|
setup func(env testEnv)
|
||||||
wantErr bool
|
want *models.SSOSettings
|
||||||
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "should return successfully",
|
name: "should return successfully",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||||
Provider: "github",
|
Provider: "github",
|
||||||
@ -72,13 +74,15 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return error if store returns an error different than not found",
|
name: "should return error if store returns an error different than not found",
|
||||||
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
provider: "github",
|
||||||
want: nil,
|
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
||||||
wantErr: true,
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should fallback to the system settings if store returns not found",
|
name: "should fallback to the system settings if store returns not found",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||||
env.fallbackStrategy.ExpectedIsMatch = true
|
env.fallbackStrategy.ExpectedIsMatch = true
|
||||||
@ -99,7 +103,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return error if the fallback strategy was not found",
|
name: "should return error if the fallback strategy was not found",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||||
env.fallbackStrategy.ExpectedIsMatch = false
|
env.fallbackStrategy.ExpectedIsMatch = false
|
||||||
@ -108,7 +113,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return error if fallback strategy returns error",
|
name: "should return error if fallback strategy returns error",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||||
env.fallbackStrategy.ExpectedIsMatch = true
|
env.fallbackStrategy.ExpectedIsMatch = true
|
||||||
@ -118,7 +124,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should decrypt secrets if data is coming from store",
|
name: "should decrypt secrets if data is coming from store",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||||
Provider: "github",
|
Provider: "github",
|
||||||
@ -152,7 +159,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should not decrypt secrets if data is coming from the fallback strategy",
|
name: "should not decrypt secrets if data is coming from the fallback strategy",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||||
env.fallbackStrategy.ExpectedIsMatch = true
|
env.fallbackStrategy.ExpectedIsMatch = true
|
||||||
@ -176,7 +184,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should return an error if the data in the store is invalid",
|
name: "should return an error if the data in the store is invalid",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||||
Provider: "github",
|
Provider: "github",
|
||||||
@ -196,7 +205,8 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "correctly merge the DB and system settings",
|
name: "correctly merge URLs from the DB and system settings",
|
||||||
|
provider: "github",
|
||||||
setup: func(env testEnv) {
|
setup: func(env testEnv) {
|
||||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||||
Provider: "github",
|
Provider: "github",
|
||||||
@ -231,6 +241,45 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "correctly merge group of settings for SAML",
|
||||||
|
provider: "saml",
|
||||||
|
setup: func(env testEnv) {
|
||||||
|
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||||
|
Provider: "saml",
|
||||||
|
Settings: map[string]any{
|
||||||
|
"certificate": base64.RawStdEncoding.EncodeToString([]byte("valid-certificate")),
|
||||||
|
"private_key_path": base64.RawStdEncoding.EncodeToString([]byte("path/to/private/key")),
|
||||||
|
"idp_metadata_url": "https://idp-metadata.com",
|
||||||
|
},
|
||||||
|
Source: models.DB,
|
||||||
|
}
|
||||||
|
env.fallbackStrategy.ExpectedIsMatch = true
|
||||||
|
env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{
|
||||||
|
"saml": {
|
||||||
|
"name": "test-settings",
|
||||||
|
"certificate_path": "path/to/certificate",
|
||||||
|
"private_key": "this-is-a-valid-private-key",
|
||||||
|
"idp_metadata_path": "path/to/metadata",
|
||||||
|
"max_issue_delay": "1h",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
env.secrets.On("Decrypt", mock.Anything, []byte("valid-certificate"), mock.Anything).Return([]byte("decrypted-valid-certificate"), nil).Once()
|
||||||
|
env.secrets.On("Decrypt", mock.Anything, []byte("path/to/private/key"), mock.Anything).Return([]byte("decrypted/path/to/private/key"), nil).Once()
|
||||||
|
},
|
||||||
|
want: &models.SSOSettings{
|
||||||
|
Provider: "saml",
|
||||||
|
Settings: map[string]any{
|
||||||
|
"name": "test-settings",
|
||||||
|
"certificate": "decrypted-valid-certificate",
|
||||||
|
"private_key_path": "decrypted/path/to/private/key",
|
||||||
|
"idp_metadata_url": "https://idp-metadata.com",
|
||||||
|
"max_issue_delay": "1h",
|
||||||
|
},
|
||||||
|
Source: models.DB,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@ -241,12 +290,12 @@ func TestService_GetForProvider(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
env := setupTestEnv(t, false, false, false)
|
env := setupTestEnv(t, true, false, true)
|
||||||
if tc.setup != nil {
|
if tc.setup != nil {
|
||||||
tc.setup(env)
|
tc.setup(env)
|
||||||
}
|
}
|
||||||
|
|
||||||
actual, err := env.service.GetForProvider(context.Background(), "github")
|
actual, err := env.service.GetForProvider(context.Background(), tc.provider)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
Loading…
Reference in New Issue
Block a user