Auth: OAuth strategy load extra fields separately (#83408)

Load extra fields separately
This commit is contained in:
Misi 2024-02-26 15:33:29 +01:00 committed by GitHub
parent 1f484fef9d
commit 617adb137c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 7 deletions

View File

@ -5,6 +5,7 @@ import (
"maps"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/connectors"
"github.com/grafana/grafana/pkg/services/ssosettings"
"github.com/grafana/grafana/pkg/setting"
)
@ -14,6 +15,14 @@ type OAuthStrategy struct {
settingsByProvider map[string]map[string]any
}
var extraKeysByProvider = map[string][]string{
social.AzureADProviderName: connectors.ExtraAzureADSettingKeys,
social.GenericOAuthProviderName: connectors.ExtraGenericOAuthSettingKeys,
social.GitHubProviderName: connectors.ExtraGithubSettingKeys,
social.GrafanaComProviderName: connectors.ExtraGrafanaComSettingKeys,
social.GrafanaNetProviderName: connectors.ExtraGrafanaComSettingKeys,
}
var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil)
func NewOAuthStrategy(cfg *setting.Cfg) *OAuthStrategy {
@ -39,7 +48,7 @@ func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (m
}
func (s *OAuthStrategy) loadAllSettings() {
allProviders := append(ssosettings.AllOAuthProviders, social.GrafanaNetProviderName)
allProviders := append([]string{social.GrafanaNetProviderName}, ssosettings.AllOAuthProviders...)
for _, provider := range allProviders {
settings := s.loadSettingsForProvider(provider)
if provider == social.GrafanaNetProviderName {
@ -52,7 +61,7 @@ func (s *OAuthStrategy) loadAllSettings() {
func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any {
section := s.cfg.Raw.Section("auth." + provider)
return map[string]any{
result := map[string]any{
"client_id": section.Key("client_id").Value(),
"client_secret": section.Key("client_secret").Value(),
"scopes": section.Key("scopes").Value(),
@ -85,10 +94,12 @@ func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any
"auto_login": section.Key("auto_login").MustBool(false),
"allowed_groups": section.Key("allowed_groups").Value(),
"signout_redirect_url": section.Key("signout_redirect_url").Value(),
"allowed_organizations": section.Key("allowed_organizations").Value(),
"id_token_attribute_name": section.Key("id_token_attribute_name").Value(),
"login_attribute_path": section.Key("login_attribute_path").Value(),
"name_attribute_path": section.Key("name_attribute_path").Value(),
"team_ids": section.Key("team_ids").Value(),
}
extraFields := extraKeysByProvider[provider]
for _, key := range extraFields {
result[key] = section.Key(key).Value()
}
return result
}

View File

@ -108,3 +108,67 @@ func TestGetProviderConfig(t *testing.T) {
require.Equal(t, expectedOAuthInfo, result)
}
func TestGetProviderConfig_ExtraFields(t *testing.T) {
iniWithExtraFields := `
[auth.azuread]
force_use_graph_api = true
allowed_organizations = org1, org2
[auth.github]
team_ids = first, second
allowed_organizations = org1, org2
[auth.generic_oauth]
name_attribute_path = name
login_attribute_path = login
id_token_attribute_name = id_token
team_ids = first, second
allowed_organizations = org1, org2
[auth.grafana_com]
allowed_organizations = org1, org2
`
iniFile, err := ini.Load([]byte(iniWithExtraFields))
require.NoError(t, err)
cfg := setting.NewCfg()
cfg.Raw = iniFile
strategy := NewOAuthStrategy(cfg)
t.Run("azuread", func(t *testing.T) {
result, err := strategy.GetProviderConfig(context.Background(), "azuread")
require.NoError(t, err)
require.Equal(t, "true", result["force_use_graph_api"])
require.Equal(t, "org1, org2", result["allowed_organizations"])
})
t.Run("github", func(t *testing.T) {
result, err := strategy.GetProviderConfig(context.Background(), "github")
require.NoError(t, err)
require.Equal(t, "first, second", result["team_ids"])
require.Equal(t, "org1, org2", result["allowed_organizations"])
})
t.Run("generic_oauth", func(t *testing.T) {
result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth")
require.NoError(t, err)
require.Equal(t, "first, second", result["team_ids"])
require.Equal(t, "org1, org2", result["allowed_organizations"])
require.Equal(t, "name", result["name_attribute_path"])
require.Equal(t, "login", result["login_attribute_path"])
require.Equal(t, "id_token", result["id_token_attribute_name"])
})
t.Run("grafana_com", func(t *testing.T) {
result, err := strategy.GetProviderConfig(context.Background(), "grafana_com")
require.NoError(t, err)
require.Equal(t, "org1, org2", result["allowed_organizations"])
})
}