From 617adb137c91ed81dfef935b7c996e72a26e5e58 Mon Sep 17 00:00:00 2001 From: Misi Date: Mon, 26 Feb 2024 15:33:29 +0100 Subject: [PATCH] Auth: OAuth strategy load extra fields separately (#83408) Load extra fields separately --- .../ssosettings/strategies/oauth_strategy.go | 25 ++++++-- .../strategies/oauth_strategy_test.go | 64 +++++++++++++++++++ 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/pkg/services/ssosettings/strategies/oauth_strategy.go b/pkg/services/ssosettings/strategies/oauth_strategy.go index 3dfe79a1bc3..22e3b2e4ba4 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy.go @@ -5,6 +5,7 @@ import ( "maps" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/connectors" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/setting" ) @@ -14,6 +15,14 @@ type OAuthStrategy struct { settingsByProvider map[string]map[string]any } +var extraKeysByProvider = map[string][]string{ + social.AzureADProviderName: connectors.ExtraAzureADSettingKeys, + social.GenericOAuthProviderName: connectors.ExtraGenericOAuthSettingKeys, + social.GitHubProviderName: connectors.ExtraGithubSettingKeys, + social.GrafanaComProviderName: connectors.ExtraGrafanaComSettingKeys, + social.GrafanaNetProviderName: connectors.ExtraGrafanaComSettingKeys, +} + var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil) func NewOAuthStrategy(cfg *setting.Cfg) *OAuthStrategy { @@ -39,7 +48,7 @@ func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (m } func (s *OAuthStrategy) loadAllSettings() { - allProviders := append(ssosettings.AllOAuthProviders, social.GrafanaNetProviderName) + allProviders := append([]string{social.GrafanaNetProviderName}, ssosettings.AllOAuthProviders...) for _, provider := range allProviders { settings := s.loadSettingsForProvider(provider) if provider == social.GrafanaNetProviderName { @@ -52,7 +61,7 @@ func (s *OAuthStrategy) loadAllSettings() { func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any { section := s.cfg.Raw.Section("auth." + provider) - return map[string]any{ + result := map[string]any{ "client_id": section.Key("client_id").Value(), "client_secret": section.Key("client_secret").Value(), "scopes": section.Key("scopes").Value(), @@ -85,10 +94,12 @@ func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any "auto_login": section.Key("auto_login").MustBool(false), "allowed_groups": section.Key("allowed_groups").Value(), "signout_redirect_url": section.Key("signout_redirect_url").Value(), - "allowed_organizations": section.Key("allowed_organizations").Value(), - "id_token_attribute_name": section.Key("id_token_attribute_name").Value(), - "login_attribute_path": section.Key("login_attribute_path").Value(), - "name_attribute_path": section.Key("name_attribute_path").Value(), - "team_ids": section.Key("team_ids").Value(), } + + extraFields := extraKeysByProvider[provider] + for _, key := range extraFields { + result[key] = section.Key(key).Value() + } + + return result } diff --git a/pkg/services/ssosettings/strategies/oauth_strategy_test.go b/pkg/services/ssosettings/strategies/oauth_strategy_test.go index 258ea41f966..172d143bb86 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy_test.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy_test.go @@ -108,3 +108,67 @@ func TestGetProviderConfig(t *testing.T) { require.Equal(t, expectedOAuthInfo, result) } + +func TestGetProviderConfig_ExtraFields(t *testing.T) { + iniWithExtraFields := ` + [auth.azuread] + force_use_graph_api = true + allowed_organizations = org1, org2 + + [auth.github] + team_ids = first, second + allowed_organizations = org1, org2 + + [auth.generic_oauth] + name_attribute_path = name + login_attribute_path = login + id_token_attribute_name = id_token + team_ids = first, second + allowed_organizations = org1, org2 + + [auth.grafana_com] + allowed_organizations = org1, org2 + ` + + iniFile, err := ini.Load([]byte(iniWithExtraFields)) + require.NoError(t, err) + + cfg := setting.NewCfg() + cfg.Raw = iniFile + + strategy := NewOAuthStrategy(cfg) + + t.Run("azuread", func(t *testing.T) { + result, err := strategy.GetProviderConfig(context.Background(), "azuread") + require.NoError(t, err) + + require.Equal(t, "true", result["force_use_graph_api"]) + require.Equal(t, "org1, org2", result["allowed_organizations"]) + }) + + t.Run("github", func(t *testing.T) { + result, err := strategy.GetProviderConfig(context.Background(), "github") + require.NoError(t, err) + + require.Equal(t, "first, second", result["team_ids"]) + require.Equal(t, "org1, org2", result["allowed_organizations"]) + }) + + t.Run("generic_oauth", func(t *testing.T) { + result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth") + require.NoError(t, err) + + require.Equal(t, "first, second", result["team_ids"]) + require.Equal(t, "org1, org2", result["allowed_organizations"]) + require.Equal(t, "name", result["name_attribute_path"]) + require.Equal(t, "login", result["login_attribute_path"]) + require.Equal(t, "id_token", result["id_token_attribute_name"]) + }) + + t.Run("grafana_com", func(t *testing.T) { + result, err := strategy.GetProviderConfig(context.Background(), "grafana_com") + require.NoError(t, err) + + require.Equal(t, "org1, org2", result["allowed_organizations"]) + }) +}