mirror of
https://github.com/grafana/grafana.git
synced 2025-02-10 07:35:45 -06:00
Auth: OAuth strategy load extra fields separately (#83408)
Load extra fields separately
This commit is contained in:
parent
1f484fef9d
commit
617adb137c
@ -5,6 +5,7 @@ import (
|
||||
"maps"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/social/connectors"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
@ -14,6 +15,14 @@ type OAuthStrategy struct {
|
||||
settingsByProvider map[string]map[string]any
|
||||
}
|
||||
|
||||
var extraKeysByProvider = map[string][]string{
|
||||
social.AzureADProviderName: connectors.ExtraAzureADSettingKeys,
|
||||
social.GenericOAuthProviderName: connectors.ExtraGenericOAuthSettingKeys,
|
||||
social.GitHubProviderName: connectors.ExtraGithubSettingKeys,
|
||||
social.GrafanaComProviderName: connectors.ExtraGrafanaComSettingKeys,
|
||||
social.GrafanaNetProviderName: connectors.ExtraGrafanaComSettingKeys,
|
||||
}
|
||||
|
||||
var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil)
|
||||
|
||||
func NewOAuthStrategy(cfg *setting.Cfg) *OAuthStrategy {
|
||||
@ -39,7 +48,7 @@ func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (m
|
||||
}
|
||||
|
||||
func (s *OAuthStrategy) loadAllSettings() {
|
||||
allProviders := append(ssosettings.AllOAuthProviders, social.GrafanaNetProviderName)
|
||||
allProviders := append([]string{social.GrafanaNetProviderName}, ssosettings.AllOAuthProviders...)
|
||||
for _, provider := range allProviders {
|
||||
settings := s.loadSettingsForProvider(provider)
|
||||
if provider == social.GrafanaNetProviderName {
|
||||
@ -52,7 +61,7 @@ func (s *OAuthStrategy) loadAllSettings() {
|
||||
func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any {
|
||||
section := s.cfg.Raw.Section("auth." + provider)
|
||||
|
||||
return map[string]any{
|
||||
result := map[string]any{
|
||||
"client_id": section.Key("client_id").Value(),
|
||||
"client_secret": section.Key("client_secret").Value(),
|
||||
"scopes": section.Key("scopes").Value(),
|
||||
@ -85,10 +94,12 @@ func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any
|
||||
"auto_login": section.Key("auto_login").MustBool(false),
|
||||
"allowed_groups": section.Key("allowed_groups").Value(),
|
||||
"signout_redirect_url": section.Key("signout_redirect_url").Value(),
|
||||
"allowed_organizations": section.Key("allowed_organizations").Value(),
|
||||
"id_token_attribute_name": section.Key("id_token_attribute_name").Value(),
|
||||
"login_attribute_path": section.Key("login_attribute_path").Value(),
|
||||
"name_attribute_path": section.Key("name_attribute_path").Value(),
|
||||
"team_ids": section.Key("team_ids").Value(),
|
||||
}
|
||||
|
||||
extraFields := extraKeysByProvider[provider]
|
||||
for _, key := range extraFields {
|
||||
result[key] = section.Key(key).Value()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
@ -108,3 +108,67 @@ func TestGetProviderConfig(t *testing.T) {
|
||||
|
||||
require.Equal(t, expectedOAuthInfo, result)
|
||||
}
|
||||
|
||||
func TestGetProviderConfig_ExtraFields(t *testing.T) {
|
||||
iniWithExtraFields := `
|
||||
[auth.azuread]
|
||||
force_use_graph_api = true
|
||||
allowed_organizations = org1, org2
|
||||
|
||||
[auth.github]
|
||||
team_ids = first, second
|
||||
allowed_organizations = org1, org2
|
||||
|
||||
[auth.generic_oauth]
|
||||
name_attribute_path = name
|
||||
login_attribute_path = login
|
||||
id_token_attribute_name = id_token
|
||||
team_ids = first, second
|
||||
allowed_organizations = org1, org2
|
||||
|
||||
[auth.grafana_com]
|
||||
allowed_organizations = org1, org2
|
||||
`
|
||||
|
||||
iniFile, err := ini.Load([]byte(iniWithExtraFields))
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := setting.NewCfg()
|
||||
cfg.Raw = iniFile
|
||||
|
||||
strategy := NewOAuthStrategy(cfg)
|
||||
|
||||
t.Run("azuread", func(t *testing.T) {
|
||||
result, err := strategy.GetProviderConfig(context.Background(), "azuread")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "true", result["force_use_graph_api"])
|
||||
require.Equal(t, "org1, org2", result["allowed_organizations"])
|
||||
})
|
||||
|
||||
t.Run("github", func(t *testing.T) {
|
||||
result, err := strategy.GetProviderConfig(context.Background(), "github")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "first, second", result["team_ids"])
|
||||
require.Equal(t, "org1, org2", result["allowed_organizations"])
|
||||
})
|
||||
|
||||
t.Run("generic_oauth", func(t *testing.T) {
|
||||
result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "first, second", result["team_ids"])
|
||||
require.Equal(t, "org1, org2", result["allowed_organizations"])
|
||||
require.Equal(t, "name", result["name_attribute_path"])
|
||||
require.Equal(t, "login", result["login_attribute_path"])
|
||||
require.Equal(t, "id_token", result["id_token_attribute_name"])
|
||||
})
|
||||
|
||||
t.Run("grafana_com", func(t *testing.T) {
|
||||
result, err := strategy.GetProviderConfig(context.Background(), "grafana_com")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "org1, org2", result["allowed_organizations"])
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user