mirror of
https://github.com/grafana/grafana.git
synced 2025-01-04 13:17:16 -06:00
SSO: Encrypt and decrypt secrets for LDAP settings (#89470)
encrypt/decrypt secrets for LDAP
This commit is contained in:
parent
f1968bbcbb
commit
4306d52353
@ -62,6 +62,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
||||
|
||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsLDAP) {
|
||||
providersList = append(providersList, social.LDAPProviderName)
|
||||
configurableProviders[social.LDAPProviderName] = true
|
||||
}
|
||||
|
||||
if licensing.FeatureEnabled(social.SAMLProviderName) {
|
||||
@ -320,21 +321,23 @@ func (s *Service) getFallbackStrategyFor(provider string) (ssosettings.FallbackS
|
||||
}
|
||||
|
||||
func (s *Service) encryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
for k, v := range settings {
|
||||
if IsSecretField(k) && v != "" {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
||||
}
|
||||
result := deepCopyMap(settings)
|
||||
configs := getConfigMaps(result)
|
||||
|
||||
encryptedSecret, err := s.secrets.Encrypt(ctx, []byte(strValue), secrets.WithoutScope())
|
||||
if err != nil {
|
||||
return result, err
|
||||
for _, config := range configs {
|
||||
for k, v := range config {
|
||||
if IsSecretField(k) && v != "" {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
||||
}
|
||||
|
||||
encryptedSecret, err := s.secrets.Encrypt(ctx, []byte(strValue), secrets.WithoutScope())
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
config[k] = base64.RawStdEncoding.EncodeToString(encryptedSecret)
|
||||
}
|
||||
result[k] = base64.RawStdEncoding.EncodeToString(encryptedSecret)
|
||||
} else {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
@ -411,29 +414,34 @@ func (s *Service) mergeSSOSettings(dbSettings, systemSettings *models.SSOSetting
|
||||
}
|
||||
|
||||
func (s *Service) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
|
||||
for k, v := range settings {
|
||||
if IsSecretField(k) && v != "" {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
s.logger.Error("Failed to parse secret value, it is not a string", "key", k)
|
||||
return nil, fmt.Errorf("secret value is not a string")
|
||||
}
|
||||
configs := getConfigMaps(settings)
|
||||
|
||||
decoded, err := base64.RawStdEncoding.DecodeString(strValue)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to decode secret string", "err", err, "value")
|
||||
return nil, err
|
||||
}
|
||||
for _, config := range configs {
|
||||
for k, v := range config {
|
||||
if IsSecretField(k) && v != "" {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
s.logger.Error("Failed to parse secret value, it is not a string", "key", k)
|
||||
return nil, fmt.Errorf("secret value is not a string")
|
||||
}
|
||||
|
||||
decrypted, err := s.secrets.Decrypt(ctx, decoded)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to decrypt secret", "err", err)
|
||||
return nil, err
|
||||
}
|
||||
decoded, err := base64.RawStdEncoding.DecodeString(strValue)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to decode secret string", "err", err, "value")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings[k] = string(decrypted)
|
||||
decrypted, err := s.secrets.Decrypt(ctx, decoded)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to decrypt secret", "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config[k] = string(decrypted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
@ -445,18 +453,39 @@ func (s *Service) isProviderConfigurable(provider string) bool {
|
||||
// removeSecrets removes all the secrets from the map and replaces them with a redacted password
|
||||
// and returns a new map
|
||||
func removeSecrets(settings map[string]any) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for k, v := range settings {
|
||||
val, ok := v.(string)
|
||||
if ok && val != "" && IsSecretField(k) {
|
||||
result[k] = setting.RedactedPassword
|
||||
continue
|
||||
result := deepCopyMap(settings)
|
||||
configs := getConfigMaps(result)
|
||||
|
||||
for _, config := range configs {
|
||||
for k, v := range config {
|
||||
val, ok := v.(string)
|
||||
if ok && val != "" && IsSecretField(k) {
|
||||
config[k] = setting.RedactedPassword
|
||||
}
|
||||
}
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getConfigMaps returns a list of maps that may contain secrets
|
||||
func getConfigMaps(settings map[string]any) []map[string]any {
|
||||
// always include the main settings map
|
||||
result := []map[string]any{settings}
|
||||
|
||||
// for LDAP include settings for each server
|
||||
if config, ok := settings["config"].(map[string]any); ok {
|
||||
if servers, ok := config["servers"].([]any); ok {
|
||||
for _, server := range servers {
|
||||
if serverSettings, ok := server.(map[string]any); ok {
|
||||
result = append(result, serverSettings)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeSettings merges two maps in a way that the values from the first map are preserved
|
||||
// and the values from the second map are added only if they don't exist in the first map
|
||||
// or if they contain empty URLs.
|
||||
@ -500,23 +529,25 @@ func isMergingAllowed(fieldName string) bool {
|
||||
|
||||
// mergeSecrets returns a new map with the current value for secrets that have not been updated
|
||||
func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) {
|
||||
settingsWithSecrets := map[string]any{}
|
||||
for k, v := range settings {
|
||||
if IsSecretField(k) {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret value is not a string")
|
||||
}
|
||||
settingsWithSecrets := deepCopyMap(settings)
|
||||
newConfigs := getConfigMaps(settingsWithSecrets)
|
||||
storedConfigs := getConfigMaps(storedSettings)
|
||||
|
||||
if isNewSecretValue(strValue) {
|
||||
settingsWithSecrets[k] = strValue // use the new value
|
||||
continue
|
||||
for i, config := range newConfigs {
|
||||
for k, v := range config {
|
||||
if IsSecretField(k) {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret value is not a string")
|
||||
}
|
||||
|
||||
if !isNewSecretValue(strValue) && len(storedConfigs) > i {
|
||||
config[k] = storedConfigs[i][k] // use the currently stored value
|
||||
}
|
||||
}
|
||||
settingsWithSecrets[k] = storedSettings[k] // keep the currently stored value
|
||||
} else {
|
||||
settingsWithSecrets[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settingsWithSecrets, nil
|
||||
}
|
||||
|
||||
@ -532,7 +563,7 @@ func overrideMaps(maps ...map[string]any) map[string]any {
|
||||
|
||||
// IsSecretField returns true if the SSO settings field provided is a secret
|
||||
func IsSecretField(fieldName string) bool {
|
||||
secretFieldPatterns := []string{"secret", "private", "certificate"}
|
||||
secretFieldPatterns := []string{"secret", "private", "certificate", "password", "client_key"}
|
||||
|
||||
for _, v := range secretFieldPatterns {
|
||||
if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) {
|
||||
@ -554,3 +585,37 @@ func isEmptyString(val any) bool {
|
||||
func isNewSecretValue(value string) bool {
|
||||
return value != setting.RedactedPassword
|
||||
}
|
||||
|
||||
func deepCopyMap(settings map[string]any) map[string]any {
|
||||
newSettings := make(map[string]any)
|
||||
|
||||
for key, value := range settings {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
newSettings[key] = deepCopyMap(v)
|
||||
case []any:
|
||||
newSettings[key] = deepCopySlice(v)
|
||||
default:
|
||||
newSettings[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return newSettings
|
||||
}
|
||||
|
||||
func deepCopySlice(s []any) []any {
|
||||
newSlice := make([]any, len(s))
|
||||
|
||||
for i, value := range s {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
newSlice[i] = deepCopyMap(v)
|
||||
case []any:
|
||||
newSlice[i] = deepCopySlice(v)
|
||||
default:
|
||||
newSlice[i] = value
|
||||
}
|
||||
}
|
||||
|
||||
return newSlice
|
||||
}
|
||||
|
@ -158,6 +158,62 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should decrypt secrets for LDAP if data is coming from store",
|
||||
provider: "ldap",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "ldap",
|
||||
Settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": base64.RawStdEncoding.EncodeToString([]byte("bind_password_1")),
|
||||
"client_key": base64.RawStdEncoding.EncodeToString([]byte("client_key_1")),
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": base64.RawStdEncoding.EncodeToString([]byte("bind_password_2")),
|
||||
"client_key": base64.RawStdEncoding.EncodeToString([]byte("client_key_2")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Source: models.DB,
|
||||
}
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{}
|
||||
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("bind_password_1"), mock.Anything).Return([]byte("decrypted-bind-password-1"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("client_key_1"), mock.Anything).Return([]byte("decrypted-client-key-1"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("bind_password_2"), mock.Anything).Return([]byte("decrypted-bind-password-2"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("client_key_2"), mock.Anything).Return([]byte("decrypted-client-key-2"), nil).Once()
|
||||
},
|
||||
want: &models.SSOSettings{
|
||||
Provider: "ldap",
|
||||
Settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": "decrypted-bind-password-1",
|
||||
"client_key": "decrypted-client-key-1",
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": "decrypted-bind-password-2",
|
||||
"client_key": "decrypted-client-key-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Source: models.DB,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should not decrypt secrets if data is coming from the fallback strategy",
|
||||
provider: "github",
|
||||
@ -290,7 +346,7 @@ func TestService_GetForProvider(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, true, false, true)
|
||||
env := setupTestEnv(t, true, false, true, true)
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
}
|
||||
@ -314,13 +370,15 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(env testEnv)
|
||||
want *models.SSOSettings
|
||||
wantErr bool
|
||||
name string
|
||||
provider string
|
||||
setup func(env testEnv)
|
||||
want *models.SSOSettings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "should return successfully and redact secrets",
|
||||
name: "should return successfully and redact secrets",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
@ -347,13 +405,67 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error if store returns an error different than not found",
|
||||
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
name: "should return successfully and redact secrets for LDAP",
|
||||
provider: "ldap",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "ldap",
|
||||
Settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": base64.RawStdEncoding.EncodeToString([]byte("bind_password_1")),
|
||||
"client_key": base64.RawStdEncoding.EncodeToString([]byte("client_key_1")),
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": base64.RawStdEncoding.EncodeToString([]byte("bind_password_2")),
|
||||
"client_key": base64.RawStdEncoding.EncodeToString([]byte("client_key_2")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Source: models.DB,
|
||||
}
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("bind_password_1"), mock.Anything).Return([]byte("decrypted-bind-password-1"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("client_key_1"), mock.Anything).Return([]byte("decrypted-client-key-1"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("bind_password_2"), mock.Anything).Return([]byte("decrypted-bind-password-2"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("client_key_2"), mock.Anything).Return([]byte("decrypted-client-key-2"), nil).Once()
|
||||
},
|
||||
want: &models.SSOSettings{
|
||||
Provider: "ldap",
|
||||
Settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": "*********",
|
||||
"client_key": "*********",
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": "*********",
|
||||
"client_key": "*********",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should fallback to strategy if store returns not found",
|
||||
name: "should return error if store returns an error different than not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "should fallback to strategy if store returns not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -371,7 +483,8 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error if the fallback strategy was not found",
|
||||
name: "should return error if the fallback strategy was not found",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = false
|
||||
@ -380,7 +493,8 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "should return error if fallback strategy returns error",
|
||||
name: "should return error if fallback strategy returns error",
|
||||
provider: "github",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -399,12 +513,12 @@ func TestService_GetForProviderWithRedactedSecrets(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, true)
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
}
|
||||
|
||||
actual, err := env.service.GetForProviderWithRedactedSecrets(context.Background(), "github")
|
||||
actual, err := env.service.GetForProviderWithRedactedSecrets(context.Background(), tc.provider)
|
||||
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
@ -550,7 +664,7 @@ func TestService_List(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
}
|
||||
@ -852,7 +966,7 @@ func TestService_ListWithRedactedSecrets(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
}
|
||||
@ -876,7 +990,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("successfully upsert SSO settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -936,10 +1050,80 @@ func TestService_Upsert(t *testing.T) {
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
|
||||
t.Run("successfully upsert SSO settings for LDAP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false, true)
|
||||
|
||||
provider := social.LDAPProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": "bind_password_1",
|
||||
"client_key": "client_key_1",
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": "bind_password_2",
|
||||
"client_key": "client_key_2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings, mock.Anything, mock.Anything).Return(nil)
|
||||
reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool {
|
||||
defer wg.Done()
|
||||
return settings.Provider == provider &&
|
||||
settings.ID == "someid" &&
|
||||
maps.Equal(settings.Settings["config"].(map[string]any)["servers"].([]any)[0].(map[string]any), map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": "bind_password_1",
|
||||
"client_key": "client_key_1",
|
||||
}) && maps.Equal(settings.Settings["config"].(map[string]any)["servers"].([]any)[1].(map[string]any), map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": "bind_password_2",
|
||||
"client_key": "client_key_2",
|
||||
})
|
||||
})).Return(nil).Once()
|
||||
env.reloadables[provider] = reloadable
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte("bind_password_1"), mock.Anything).Return([]byte("encrypted-bind-password-1"), nil).Once()
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte("bind_password_2"), mock.Anything).Return([]byte("encrypted-bind-password-2"), nil).Once()
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte("client_key_1"), mock.Anything).Return([]byte("encrypted-client-key-1"), nil).Once()
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte("client_key_2"), mock.Anything).Return([]byte("encrypted-client-key-2"), nil).Once()
|
||||
|
||||
env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
|
||||
currentTime := time.Now()
|
||||
settings.ID = "someid"
|
||||
settings.Created = currentTime
|
||||
settings.Updated = currentTime
|
||||
|
||||
env.store.ActualSSOSettings = *settings
|
||||
return nil
|
||||
}
|
||||
|
||||
err := env.service.Upsert(context.Background(), &settings, &user.SignedInUser{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine first to assert the Reload call
|
||||
wg.Wait()
|
||||
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
|
||||
t.Run("returns error if provider is not configurable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.GrafanaComProviderName
|
||||
settings := &models.SSOSettings{
|
||||
@ -962,7 +1146,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if provider was not found in reloadables", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := &models.SSOSettings{
|
||||
@ -986,7 +1170,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if validation fails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1010,7 +1194,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
settings := &models.SSOSettings{
|
||||
Provider: social.AzureADProviderName,
|
||||
@ -1031,7 +1215,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if a secret does not have the type string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.OktaProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1054,7 +1238,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.OktaProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1079,7 +1263,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("should not update the current secret if the secret has not been updated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1123,7 +1307,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("run validation with all new and current secrets available in settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1176,7 +1360,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1208,7 +1392,7 @@ func TestService_Upsert(t *testing.T) {
|
||||
t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
@ -1241,7 +1425,7 @@ func TestService_Delete(t *testing.T) {
|
||||
t.Run("successfully delete SSO settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@ -1279,7 +1463,7 @@ func TestService_Delete(t *testing.T) {
|
||||
t.Run("return error if SSO setting was not found for the specified provider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
@ -1295,7 +1479,7 @@ func TestService_Delete(t *testing.T) {
|
||||
t.Run("should not delete the SSO settings if the provider is not configurable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
env.cfg.SSOSettingsConfigurableProviders = map[string]bool{social.AzureADProviderName: true}
|
||||
|
||||
provider := social.GrafanaComProviderName
|
||||
@ -1308,7 +1492,7 @@ func TestService_Delete(t *testing.T) {
|
||||
t.Run("return error when store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
env.store.ExpectedError = errors.New("delete sso settings failed")
|
||||
@ -1321,7 +1505,7 @@ func TestService_Delete(t *testing.T) {
|
||||
t.Run("return successfully when the deletion was successful but reloading the settings fail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
@ -1337,13 +1521,51 @@ func TestService_Delete(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// we might not need this test because it is not testing the public interface
|
||||
// it was added for convenient testing of the internal deep copy and remove secrets
|
||||
func TestRemoveSecrets(t *testing.T) {
|
||||
settings := map[string]any{
|
||||
"enabled": true,
|
||||
"client_secret": "client_secret",
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"host": "192.168.0.1",
|
||||
"bind_password": "bind_password_1",
|
||||
"client_key": "client_key_1",
|
||||
},
|
||||
map[string]any{
|
||||
"host": "192.168.0.2",
|
||||
"bind_password": "bind_password_2",
|
||||
"client_key": "client_key_2",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
copiedSettings := deepCopyMap(settings)
|
||||
copiedSettings["client_secret"] = "client_secret_updated"
|
||||
copiedSettings["config"].(map[string]any)["servers"].([]any)[0].(map[string]any)["bind_password"] = "bind_password_1_updated"
|
||||
|
||||
require.Equal(t, "client_secret", settings["client_secret"])
|
||||
require.Equal(t, "client_secret_updated", copiedSettings["client_secret"])
|
||||
require.Equal(t, "bind_password_1", settings["config"].(map[string]any)["servers"].([]any)[0].(map[string]any)["bind_password"])
|
||||
require.Equal(t, "bind_password_1_updated", copiedSettings["config"].(map[string]any)["servers"].([]any)[0].(map[string]any)["bind_password"])
|
||||
|
||||
settingsWithRedactedSecrets := removeSecrets(settings)
|
||||
require.Equal(t, "client_secret", settings["client_secret"])
|
||||
require.Equal(t, "*********", settingsWithRedactedSecrets["client_secret"])
|
||||
require.Equal(t, "bind_password_1", settings["config"].(map[string]any)["servers"].([]any)[0].(map[string]any)["bind_password"])
|
||||
require.Equal(t, "*********", settingsWithRedactedSecrets["config"].(map[string]any)["servers"].([]any)[0].(map[string]any)["bind_password"])
|
||||
}
|
||||
|
||||
func TestService_DoReload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("successfully reload settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
settingsList := []*models.SSOSettings{
|
||||
{
|
||||
@ -1383,7 +1605,7 @@ func TestService_DoReload(t *testing.T) {
|
||||
t.Run("successfully reload settings when some providers have empty settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
settingsList := []*models.SSOSettings{
|
||||
{
|
||||
@ -1413,7 +1635,7 @@ func TestService_DoReload(t *testing.T) {
|
||||
t.Run("failed fetching the SSO settings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
provider := "github"
|
||||
|
||||
@ -1459,6 +1681,35 @@ func TestService_decryptSecrets(t *testing.T) {
|
||||
"certificate": "decrypted-certificate",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should decrypt LDAP secrets successfully",
|
||||
setup: func(env testEnv) {
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("client_key"), mock.Anything).Return([]byte("decrypted-client-key"), nil).Once()
|
||||
env.secrets.On("Decrypt", mock.Anything, []byte("bind_password"), mock.Anything).Return([]byte("decrypted-bind-password"), nil).Once()
|
||||
},
|
||||
settings: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"client_key": base64.RawStdEncoding.EncodeToString([]byte("client_key")),
|
||||
"bind_password": base64.RawStdEncoding.EncodeToString([]byte("bind_password")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"enabled": true,
|
||||
"config": map[string]any{
|
||||
"servers": []any{
|
||||
map[string]any{
|
||||
"client_key": "decrypted-client-key",
|
||||
"bind_password": "decrypted-bind-password",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should not decrypt when a secret is empty",
|
||||
setup: func(env testEnv) {
|
||||
@ -1514,7 +1765,7 @@ func TestService_decryptSecrets(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, false, false, false)
|
||||
env := setupTestEnv(t, false, false, false, false)
|
||||
|
||||
if tc.setup != nil {
|
||||
tc.setup(env)
|
||||
@ -1593,7 +1844,7 @@ func Test_ProviderService(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.samlEnabled)
|
||||
env := setupTestEnv(t, tc.isLicenseEnabled, true, tc.samlEnabled, false)
|
||||
|
||||
require.Equal(t, tc.expectedProvidersList, env.service.providersList)
|
||||
require.Equal(t, tc.strategiesLength, len(env.service.fbStrategies))
|
||||
@ -1601,7 +1852,7 @@ func Test_ProviderService(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies, samlEnabled bool) testEnv {
|
||||
func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies, samlEnabled bool, ldapEnabled bool) testEnv {
|
||||
t.Helper()
|
||||
|
||||
store := ssosettingstests.NewFakeStore()
|
||||
@ -1631,10 +1882,14 @@ func setupTestEnv(t *testing.T, isLicensingEnabled, keepFallbackStratergies, sam
|
||||
licensing := licensingtest.NewFakeLicensing()
|
||||
licensing.On("FeatureEnabled", "saml").Return(isLicensingEnabled)
|
||||
|
||||
featureManager := featuremgmt.WithManager()
|
||||
features := make([]any, 0)
|
||||
if samlEnabled {
|
||||
featureManager = featuremgmt.WithManager(featuremgmt.FlagSsoSettingsSAML)
|
||||
features = append(features, featuremgmt.FlagSsoSettingsSAML)
|
||||
}
|
||||
if ldapEnabled {
|
||||
features = append(features, featuremgmt.FlagSsoSettingsLDAP)
|
||||
}
|
||||
featureManager := featuremgmt.WithManager(features...)
|
||||
|
||||
svc := ProvideService(
|
||||
cfg,
|
||||
|
Loading…
Reference in New Issue
Block a user