From 20bb0a3ab173fc8382e87bcf9c492a684d9c63d7 Mon Sep 17 00:00:00 2001 From: Misi Date: Mon, 22 Jan 2024 14:54:48 +0100 Subject: [PATCH] AuthN: Support reloading SSO config after the sso settings have changed (#80734) * Add AuthNSvc reload handling * Working, need to add test * Remove commented out code * Add Reload implementation to connectors * Align and add tests, refactor * Add more tests, linting * Add extra checks + tests to oauth client * Clean up based on reviews * Move config instantiation into newSocialBase * Use specific error --- pkg/login/social/connectors/azuread_oauth.go | 26 +++- .../social/connectors/azuread_oauth_test.go | 83 ++++++++++- pkg/login/social/connectors/generic_oauth.go | 28 +++- .../social/connectors/generic_oauth_test.go | 134 +++++++++++++++++- pkg/login/social/connectors/github_oauth.go | 35 ++++- .../social/connectors/github_oauth_test.go | 85 ++++++++++- pkg/login/social/connectors/gitlab_oauth.go | 17 ++- .../social/connectors/gitlab_oauth_test.go | 32 +++-- pkg/login/social/connectors/google_oauth.go | 23 ++- .../social/connectors/google_oauth_test.go | 30 +++- .../social/connectors/grafana_com_oauth.go | 25 +++- .../connectors/grafana_com_oauth_test.go | 97 ++++++++++++- pkg/login/social/connectors/okta_oauth.go | 22 ++- .../social/connectors/okta_oauth_test.go | 30 +++- pkg/login/social/connectors/social_base.go | 50 +++---- pkg/services/authn/authnimpl/service.go | 14 +- pkg/services/authn/clients/oauth.go | 83 +++++++---- pkg/services/authn/clients/oauth_test.go | 85 +++++++---- pkg/services/ssosettings/api/api.go | 2 +- pkg/services/ssosettings/api/api_test.go | 2 +- pkg/services/ssosettings/database/database.go | 13 +- .../ssosettings/database/database_test.go | 10 +- pkg/services/ssosettings/errors.go | 6 +- pkg/services/ssosettings/ssosettings.go | 4 +- .../ssosettings/ssosettingsimpl/service.go | 69 +++++++-- .../ssosettingsimpl/service_test.go | 75 ++++++++-- .../ssosettingstests/reloadable_mock.go | 2 +- .../ssosettingstests/service_mock.go | 6 +- .../ssosettingstests/store_fake.go | 6 +- .../ssosettingstests/store_mock.go | 6 +- .../ssosettings/strategies/oauth_strategy.go | 6 +- 31 files changed, 889 insertions(+), 217 deletions(-) diff --git a/pkg/login/social/connectors/azuread_oauth.go b/pkg/login/social/connectors/azuread_oauth.go index 86b22c042da..98f21607146 100644 --- a/pkg/login/social/connectors/azuread_oauth.go +++ b/pkg/login/social/connectors/azuread_oauth.go @@ -73,16 +73,15 @@ type keySetJWKS struct { } func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD { - config := createOAuthConfig(info, cfg, social.AzureADProviderName) provider := &SocialAzureAD{ - SocialBase: newSocialBase(social.AzureADProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.AzureADProviderName, info, features, cfg), cache: cache, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false), } if info.UseRefreshToken { - appendUniqueScope(config, social.OfflineAccessScope) + appendUniqueScope(provider.Config, social.OfflineAccessScope) } if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { @@ -164,6 +163,27 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token }, nil } +func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.AzureADProviderName, newInfo, s.features, s.cfg) + + if newInfo.UseRefreshToken { + appendUniqueScope(s.Config, social.OfflineAccessScope) + } + + s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey]) + s.forceUseGraphAPI = MustBool(newInfo.Extra[forceUseGraphAPIKey], false) + + return nil +} + func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { info, err := CreateOAuthInfoFromKeyValues(settings.Settings) if err != nil { diff --git a/pkg/login/social/connectors/azuread_oauth_test.go b/pkg/login/social/connectors/azuread_oauth_test.go index 18191e169c0..aed48e98d9b 100644 --- a/pkg/login/social/connectors/azuread_oauth_test.go +++ b/pkg/login/social/connectors/azuread_oauth_test.go @@ -1048,11 +1048,12 @@ func TestSocialAzureAD_Validate(t *testing.T) { func TestSocialAzureAD_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -1073,6 +1074,14 @@ func TestSocialAzureAD_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/azuread", + }, }, { name: "fails if settings contain invalid values", @@ -1092,6 +1101,11 @@ func TestSocialAzureAD_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/azuread", + }, }, } @@ -1102,10 +1116,65 @@ func TestSocialAzureAD_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) + }) + } +} + +func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) { + testCases := []struct { + name string + settings ssoModels.SSOSettings + info *social.OAuthInfo + expectError bool + expectedInfo *social.OAuthInfo + expectedAllowedOrganizations []string + expectedForceUseGraphApi bool + }{ + { + name: "successfully reloads the settings", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + Extra: map[string]string{ + "allowed_organizations": "previous", + "force_use_graph_api": "true", + }, + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "allowed_organizations": "uuid-1234,uuid-5678", + "force_use_graph_api": "false", + }, + }, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + Name: "a-new-name", + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + "force_use_graph_api": "false", + }, + }, + expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"}, + expectedForceUseGraphApi: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewAzureADProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{}) + + err := s.Reload(context.Background(), tc.settings) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations) + require.EqualValues(t, tc.expectedForceUseGraphApi, s.forceUseGraphAPI) }) } } diff --git a/pkg/login/social/connectors/generic_oauth.go b/pkg/login/social/connectors/generic_oauth.go index fcc8f33ca56..87dce610fc7 100644 --- a/pkg/login/social/connectors/generic_oauth.go +++ b/pkg/login/social/connectors/generic_oauth.go @@ -46,9 +46,8 @@ type SocialGenericOAuth struct { } func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth { - config := createOAuthConfig(info, cfg, social.GenericOAuthProviderName) provider := &SocialGenericOAuth{ - SocialBase: newSocialBase(social.GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.GenericOAuthProviderName, info, features, cfg), teamsUrl: info.TeamsUrl, emailAttributeName: info.EmailAttributeName, emailAttributePath: info.EmailAttributePath, @@ -84,6 +83,31 @@ func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SS return nil } +func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.GenericOAuthProviderName, newInfo, s.features, s.cfg) + + s.teamsUrl = newInfo.TeamsUrl + s.emailAttributeName = newInfo.EmailAttributeName + s.emailAttributePath = newInfo.EmailAttributePath + s.nameAttributePath = newInfo.Extra[nameAttributePathKey] + s.groupsAttributePath = newInfo.GroupsAttributePath + s.loginAttributePath = newInfo.Extra[loginAttributePathKey] + s.idTokenAttributeName = newInfo.Extra[idTokenAttributeNameKey] + s.teamIdsAttributePath = newInfo.TeamIdsAttributePath + s.teamIds = util.SplitString(newInfo.Extra[teamIdsKey]) + s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey]) + + return nil +} + // TODOD: remove this in the next PR and use the isGroupMember from social.go func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool { info := s.GetOAuthInfo() diff --git a/pkg/login/social/connectors/generic_oauth_test.go b/pkg/login/social/connectors/generic_oauth_test.go index d19bcabd63f..bf6dc50c366 100644 --- a/pkg/login/social/connectors/generic_oauth_test.go +++ b/pkg/login/social/connectors/generic_oauth_test.go @@ -976,11 +976,12 @@ func TestSocialGenericOAuth_Validate(t *testing.T) { func TestSocialGenericOAuth_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -1001,6 +1002,14 @@ func TestSocialGenericOAuth_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/generic_oauth", + }, }, { name: "fails if settings contain invalid values", @@ -1020,6 +1029,11 @@ func TestSocialGenericOAuth_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/generic_oauth", + }, }, } @@ -1030,10 +1044,116 @@ func TestSocialGenericOAuth_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) + }) + } +} + +func TestGenericOAuth_Reload_ExtraFields(t *testing.T) { + testCases := []struct { + name string + settings ssoModels.SSOSettings + info *social.OAuthInfo + expectError bool + expectedInfo *social.OAuthInfo + expectedTeamsUrl string + expectedEmailAttributeName string + expectedEmailAttributePath string + expectedNameAttributePath string + expectedGroupsAttributePath string + expectedLoginAttributePath string + expectedIdTokenAttributeName string + expectedTeamIdsAttributePath string + expectedTeamIds []string + expectedAllowedOrganizations []string + }{ + { + name: "successfully reloads the settings", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + TeamsUrl: "https://host/users", + EmailAttributePath: "email-attr-path", + EmailAttributeName: "email-attr-name", + GroupsAttributePath: "groups-attr-path", + TeamIdsAttributePath: "team-ids-attr-path", + Extra: map[string]string{ + teamIdsKey: "team1", + allowedOrganizationsKey: "org1", + loginAttributePathKey: "login-attr-path", + idTokenAttributeNameKey: "id-token-attr-name", + nameAttributePathKey: "name-attr-path", + }, + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "teams_url": "https://host/v2/users", + "email_attribute_path": "new-email-attr-path", + "email_attribute_name": "new-email-attr-name", + "groups_attribute_path": "new-group-attr-path", + "team_ids_attribute_path": "new-team-ids-attr-path", + teamIdsKey: "team1,team2", + allowedOrganizationsKey: "org1,org2", + loginAttributePathKey: "new-login-attr-path", + idTokenAttributeNameKey: "new-id-token-attr-name", + nameAttributePathKey: "new-name-attr-path", + }, + }, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + TeamsUrl: "https://host/v2/users", + EmailAttributePath: "new-email-attr-path", + EmailAttributeName: "new-email-attr-name", + GroupsAttributePath: "new-group-attr-path", + TeamIdsAttributePath: "new-team-ids-attr-path", + Extra: map[string]string{ + teamIdsKey: "team1,team2", + allowedOrganizationsKey: "org1,org2", + loginAttributePathKey: "new-login-attr-path", + idTokenAttributeNameKey: "new-id-token-attr-name", + nameAttributePathKey: "new-name-attr-path", + }, + }, + expectedTeamsUrl: "https://host/v2/users", + expectedEmailAttributeName: "new-email-attr-name", + expectedEmailAttributePath: "new-email-attr-path", + expectedGroupsAttributePath: "new-group-attr-path", + expectedTeamIdsAttributePath: "new-team-ids-attr-path", + expectedTeamIds: []string{"team1", "team2"}, + expectedAllowedOrganizations: []string{"org1", "org2"}, + expectedLoginAttributePath: "new-login-attr-path", + expectedIdTokenAttributeName: "new-id-token-attr-name", + expectedNameAttributePath: "new-name-attr-path", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGenericOAuthProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedInfo, s.info) + + require.EqualValues(t, tc.expectedTeamsUrl, s.teamsUrl) + require.EqualValues(t, tc.expectedEmailAttributeName, s.emailAttributeName) + require.EqualValues(t, tc.expectedEmailAttributePath, s.emailAttributePath) + require.EqualValues(t, tc.expectedGroupsAttributePath, s.groupsAttributePath) + require.EqualValues(t, tc.expectedTeamIdsAttributePath, s.teamIdsAttributePath) + require.EqualValues(t, tc.expectedTeamIds, s.teamIds) + require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations) + require.EqualValues(t, tc.expectedLoginAttributePath, s.loginAttributePath) + require.EqualValues(t, tc.expectedIdTokenAttributeName, s.idTokenAttributeName) + require.EqualValues(t, tc.expectedNameAttributePath, s.nameAttributePath) }) } } diff --git a/pkg/login/social/connectors/github_oauth.go b/pkg/login/social/connectors/github_oauth.go index 749596df075..4a89bac5eed 100644 --- a/pkg/login/social/connectors/github_oauth.go +++ b/pkg/login/social/connectors/github_oauth.go @@ -57,9 +57,8 @@ func NewGitHubProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings sso teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey]) teamIds := mustInts(teamIdsSplitted) - config := createOAuthConfig(info, cfg, social.GitHubProviderName) provider := &SocialGithub{ - SocialBase: newSocialBase(social.GitHubProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.GitHubProviderName, info, features, cfg), teamIds: teamIds, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), } @@ -86,7 +85,37 @@ func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSetti return err } - // add specific validation rules for Github + teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey]) + teamIds := mustInts(teamIdsSplitted) + + if len(teamIdsSplitted) != len(teamIds) { + s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted) + return ssosettings.ErrInvalidSettings.Errorf("Failed to parse team ids. Team ids must be a list of numbers.") + } + + return nil +} + +func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + teamIdsSplitted := util.SplitString(newInfo.Extra[teamIdsKey]) + teamIds := mustInts(teamIdsSplitted) + + if len(teamIdsSplitted) != len(teamIds) { + s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted) + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.GitHubProviderName, newInfo, s.features, s.cfg) + + s.teamIds = teamIds + s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey]) return nil } diff --git a/pkg/login/social/connectors/github_oauth_test.go b/pkg/login/social/connectors/github_oauth_test.go index 9b6613a0e0a..974969518b9 100644 --- a/pkg/login/social/connectors/github_oauth_test.go +++ b/pkg/login/social/connectors/github_oauth_test.go @@ -402,11 +402,12 @@ func TestSocialGitHub_Validate(t *testing.T) { func TestSocialGitHub_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -427,6 +428,14 @@ func TestSocialGitHub_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/github", + }, }, { name: "fails if settings contain invalid values", @@ -446,6 +455,11 @@ func TestSocialGitHub_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/github", + }, }, } @@ -456,10 +470,67 @@ func TestSocialGitHub_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) + }) + } +} + +func TestGitHub_Reload_ExtraFields(t *testing.T) { + testCases := []struct { + name string + settings ssoModels.SSOSettings + info *social.OAuthInfo + expectError bool + expectedInfo *social.OAuthInfo + expectedAllowedOrganizations []string + expectedTeamIds []int + }{ + { + name: "successfully reloads the settings", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + Extra: map[string]string{ + "allowed_organizations": "previous", + "team_ids": "", + }, + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "allowed_organizations": "uuid-1234,uuid-5678", + "team_ids": "123,456", + }, + }, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + Name: "a-new-name", + AuthStyle: "inheader", + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + "force_use_graph_api": "false", + }, + }, + expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"}, + expectedTeamIds: []int{123, 456}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGitHubProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations) + require.EqualValues(t, tc.expectedTeamIds, s.teamIds) }) } } diff --git a/pkg/login/social/connectors/gitlab_oauth.go b/pkg/login/social/connectors/gitlab_oauth.go index fbdafd557dc..ee59a8b6e92 100644 --- a/pkg/login/social/connectors/gitlab_oauth.go +++ b/pkg/login/social/connectors/gitlab_oauth.go @@ -53,9 +53,8 @@ type userData struct { } func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab { - config := createOAuthConfig(info, cfg, social.GitlabProviderName) provider := &SocialGitlab{ - SocialBase: newSocialBase(social.GitlabProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.GitlabProviderName, info, features, cfg), } if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { @@ -76,7 +75,19 @@ func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSetti return err } - // add specific validation rules for Gitlab + return nil +} + +func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.GitlabProviderName, newInfo, s.features, s.cfg) return nil } diff --git a/pkg/login/social/connectors/gitlab_oauth_test.go b/pkg/login/social/connectors/gitlab_oauth_test.go index 1eeac7a7f86..00a86153159 100644 --- a/pkg/login/social/connectors/gitlab_oauth_test.go +++ b/pkg/login/social/connectors/gitlab_oauth_test.go @@ -164,7 +164,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) { for _, test := range tests { provider.info.RoleAttributePath = test.RoleAttributePath provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin - provider.autoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole) + provider.cfg.AutoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole) provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync @@ -520,11 +520,12 @@ func TestSocialGitlab_Validate(t *testing.T) { func TestSocialGitlab_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -545,6 +546,14 @@ func TestSocialGitlab_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/gitlab", + }, }, { name: "fails if settings contain invalid values", @@ -564,6 +573,11 @@ func TestSocialGitlab_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/gitlab", + }, }, } @@ -574,10 +588,12 @@ func TestSocialGitlab_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) }) } } diff --git a/pkg/login/social/connectors/google_oauth.go b/pkg/login/social/connectors/google_oauth.go index cf74c150d5c..c4c731068ca 100644 --- a/pkg/login/social/connectors/google_oauth.go +++ b/pkg/login/social/connectors/google_oauth.go @@ -39,9 +39,8 @@ type googleUserData struct { } func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle { - config := createOAuthConfig(info, cfg, social.GoogleProviderName) provider := &SocialGoogle{ - SocialBase: newSocialBase(social.GoogleProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.GoogleProviderName, info, features, cfg), } if strings.HasPrefix(info.ApiUrl, legacyAPIURL) { @@ -71,6 +70,24 @@ func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSetti return nil } +func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + if strings.HasPrefix(newInfo.ApiUrl, legacyAPIURL) { + s.log.Warn("Using legacy Google API URL, please update your configuration") + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.GoogleProviderName, newInfo, s.features, s.cfg) + + return nil +} + func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { info := s.GetOAuthInfo() @@ -225,7 +242,7 @@ type googleGroupResp struct { } func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) { - s.log.Debug("Retrieving groups", "scopes", s.SocialBase.Config.Scopes) + s.log.Debug("Retrieving groups", "scopes", s.Config.Scopes) if !slices.Contains(s.Scopes, googleIAMScope) { return nil, nil } diff --git a/pkg/login/social/connectors/google_oauth_test.go b/pkg/login/social/connectors/google_oauth_test.go index 5cd5627d589..79a4cb9d8cf 100644 --- a/pkg/login/social/connectors/google_oauth_test.go +++ b/pkg/login/social/connectors/google_oauth_test.go @@ -725,11 +725,12 @@ func TestSocialGoogle_Validate(t *testing.T) { func TestSocialGoogle_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -750,6 +751,14 @@ func TestSocialGoogle_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/google", + }, }, { name: "fails if settings contain invalid values", @@ -769,6 +778,11 @@ func TestSocialGoogle_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/google", + }, }, } @@ -779,10 +793,12 @@ func TestSocialGoogle_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) }) } } diff --git a/pkg/login/social/connectors/grafana_com_oauth.go b/pkg/login/social/connectors/grafana_com_oauth.go index ced97b6459b..60c3d23e51e 100644 --- a/pkg/login/social/connectors/grafana_com_oauth.go +++ b/pkg/login/social/connectors/grafana_com_oauth.go @@ -39,9 +39,8 @@ func NewGrafanaComProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token" info.AuthStyle = "inheader" - config := createOAuthConfig(info, cfg, social.GrafanaComProviderName) provider := &SocialGrafanaCom{ - SocialBase: newSocialBase(social.GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.GrafanaComProviderName, info, features, cfg), url: cfg.GrafanaComURL, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), } @@ -69,6 +68,28 @@ func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOS return nil } +func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + // Override necessary settings + newInfo.AuthUrl = s.cfg.GrafanaComURL + "/oauth2/authorize" + newInfo.TokenUrl = s.cfg.GrafanaComURL + "/api/oauth2/token" + newInfo.AuthStyle = "inheader" + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.GrafanaComProviderName, newInfo, s.features, s.cfg) + + s.url = s.cfg.GrafanaComURL + s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey]) + + return nil +} + func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { return true } diff --git a/pkg/login/social/connectors/grafana_com_oauth_test.go b/pkg/login/social/connectors/grafana_com_oauth_test.go index 340650f5079..bdcf15e081b 100644 --- a/pkg/login/social/connectors/grafana_com_oauth_test.go +++ b/pkg/login/social/connectors/grafana_com_oauth_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" @@ -195,11 +196,12 @@ func TestSocialGrafanaCom_Reload(t *testing.T) { const GrafanaComURL = "http://localhost:3000" testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -219,6 +221,19 @@ func TestSocialGrafanaCom_Reload(t *testing.T) { ClientId: "new-client-id", ClientSecret: "new-client-secret", Name: "a-new-name", + AuthUrl: GrafanaComURL + "/oauth2/authorize", + TokenUrl: GrafanaComURL + "/api/oauth2/token", + AuthStyle: "inheader", + }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: GrafanaComURL + "/oauth2/authorize", + TokenURL: GrafanaComURL + "/api/oauth2/token", + AuthStyle: oauth2.AuthStyleInHeader, + }, + RedirectURL: "/login/grafana_com", }, }, { @@ -243,6 +258,16 @@ func TestSocialGrafanaCom_Reload(t *testing.T) { TokenUrl: GrafanaComURL + "/api/oauth2/token", AuthStyle: "inheader", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: GrafanaComURL + "/oauth2/authorize", + TokenURL: GrafanaComURL + "/api/oauth2/token", + AuthStyle: oauth2.AuthStyleInHeader, + }, + RedirectURL: "/login/grafana_com", + }, }, } @@ -256,10 +281,68 @@ func TestSocialGrafanaCom_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) + }) + } +} + +func TestSocialGrafanaCom_Reload_ExtraFields(t *testing.T) { + const GrafanaComURL = "http://localhost:3000" + + testCases := []struct { + name string + settings ssoModels.SSOSettings + info *social.OAuthInfo + expectError bool + expectedInfo *social.OAuthInfo + expectedAllowedOrganizations []string + }{ + { + name: "successfully reloads the allowed organizations when they are set in the settings", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + Extra: map[string]string{ + "allowed_organizations": "previous", + }, + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, + }, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + Name: "a-new-name", + AuthUrl: GrafanaComURL + "/oauth2/authorize", + TokenUrl: GrafanaComURL + "/api/oauth2/token", + AuthStyle: "inheader", + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, + }, + expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cfg := &setting.Cfg{ + GrafanaComURL: GrafanaComURL, + } + s := NewGrafanaComProvider(tc.info, cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations) }) } } diff --git a/pkg/login/social/connectors/okta_oauth.go b/pkg/login/social/connectors/okta_oauth.go index aab500562b4..39eb41a1bc6 100644 --- a/pkg/login/social/connectors/okta_oauth.go +++ b/pkg/login/social/connectors/okta_oauth.go @@ -45,13 +45,12 @@ type OktaClaims struct { } func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta { - config := createOAuthConfig(info, cfg, social.OktaProviderName) provider := &SocialOkta{ - SocialBase: newSocialBase(social.OktaProviderName, config, info, cfg.AutoAssignOrgRole, features), + SocialBase: newSocialBase(social.OktaProviderName, info, features, cfg), } if info.UseRefreshToken { - appendUniqueScope(config, social.OfflineAccessScope) + appendUniqueScope(provider.Config, social.OfflineAccessScope) } if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { @@ -77,6 +76,23 @@ func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSetting return nil } +func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + s.reloadMutex.Lock() + defer s.reloadMutex.Unlock() + + s.SocialBase = newSocialBase(social.OktaProviderName, newInfo, s.features, s.cfg) + if newInfo.UseRefreshToken { + appendUniqueScope(s.Config, social.OfflineAccessScope) + } + + return nil +} + func (claims *OktaClaims) extractEmail() string { if claims.Email == "" && claims.PreferredUsername != "" { return claims.PreferredUsername diff --git a/pkg/login/social/connectors/okta_oauth_test.go b/pkg/login/social/connectors/okta_oauth_test.go index f0d1ca82fe8..6e332a0a1df 100644 --- a/pkg/login/social/connectors/okta_oauth_test.go +++ b/pkg/login/social/connectors/okta_oauth_test.go @@ -193,11 +193,12 @@ func TestSocialOkta_Validate(t *testing.T) { func TestSocialOkta_Reload(t *testing.T) { testCases := []struct { - name string - info *social.OAuthInfo - settings ssoModels.SSOSettings - expectError bool - expectedInfo *social.OAuthInfo + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + expectedConfig *oauth2.Config }{ { name: "SSO provider successfully updated", @@ -218,6 +219,14 @@ func TestSocialOkta_Reload(t *testing.T) { ClientSecret: "new-client-secret", AuthUrl: "some-new-url", }, + expectedConfig: &oauth2.Config{ + ClientID: "new-client-id", + ClientSecret: "new-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "some-new-url", + }, + RedirectURL: "/login/okta", + }, }, { name: "fails if settings contain invalid values", @@ -237,6 +246,11 @@ func TestSocialOkta_Reload(t *testing.T) { ClientId: "client-id", ClientSecret: "client-secret", }, + expectedConfig: &oauth2.Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURL: "/login/okta", + }, }, } @@ -247,10 +261,12 @@ func TestSocialOkta_Reload(t *testing.T) { err := s.Reload(context.Background(), tc.settings) if tc.expectError { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) + require.EqualValues(t, tc.expectedInfo, s.info) + require.EqualValues(t, tc.expectedConfig, s.Config) }) } } diff --git a/pkg/login/social/connectors/social_base.go b/pkg/login/social/connectors/social_base.go index 62c4e8247b2..10f990f5600 100644 --- a/pkg/login/social/connectors/social_base.go +++ b/pkg/login/social/connectors/social_base.go @@ -3,7 +3,6 @@ package connectors import ( "bytes" "compress/zlib" - "context" "encoding/base64" "encoding/json" "fmt" @@ -21,32 +20,31 @@ import ( "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/ssosettings" - ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" + "github.com/grafana/grafana/pkg/setting" ) type SocialBase struct { *oauth2.Config - info *social.OAuthInfo - infoMutex sync.RWMutex - log log.Logger - autoAssignOrgRole string - features featuremgmt.FeatureToggles + info *social.OAuthInfo + cfg *setting.Cfg + reloadMutex sync.RWMutex + log log.Logger + features featuremgmt.FeatureToggles } func newSocialBase(name string, - config *oauth2.Config, info *social.OAuthInfo, - autoAssignOrgRole string, features featuremgmt.FeatureToggles, + cfg *setting.Cfg, ) *SocialBase { logger := log.New("oauth." + name) return &SocialBase{ - Config: config, - info: info, - log: logger, - autoAssignOrgRole: autoAssignOrgRole, - features: features, + Config: createOAuthConfig(info, cfg, name), + info: info, + log: logger, + features: features, + cfg: cfg, } } @@ -60,7 +58,7 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error { bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin)) bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup)) bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains)) - bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole)) + bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.cfg.AutoAssignOrgRole)) bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath)) bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict)) bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync)) @@ -76,26 +74,12 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error { } func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo { - s.infoMutex.RLock() - defer s.infoMutex.RUnlock() + s.reloadMutex.RLock() + defer s.reloadMutex.RUnlock() return s.info } -func (s *SocialBase) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - info, err := CreateOAuthInfoFromKeyValues(settings.Settings) - if err != nil { - return fmt.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) - } - - s.infoMutex.Lock() - defer s.infoMutex.Unlock() - - s.info = info - - return nil -} - func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) { if s.info.RoleAttributePath == "" { if s.info.RoleAttributeStrict { @@ -145,9 +129,9 @@ func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType, // defaultRole returns the default role for the user based on the autoAssignOrgRole setting // if legacy is enabled "" is returned indicating the previous role assignment is used. func (s *SocialBase) defaultRole() org.RoleType { - if s.autoAssignOrgRole != "" { + if s.cfg.AutoAssignOrgRole != "" { s.log.Debug("No role found, returning default.") - return org.RoleType(s.autoAssignOrgRole) + return org.RoleType(s.cfg.AutoAssignOrgRole) } // should never happen diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index d5d42591d06..a6bc4d451cd 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -140,18 +140,8 @@ func ProvideService( } for name := range socialService.GetOAuthProviders() { - oauthCfg := socialService.GetOAuthInfoProvider(name) - if oauthCfg != nil && oauthCfg.Enabled { - clientName := authn.ClientWithPrefix(name) - - connector, errConnector := socialService.GetConnector(name) - httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name) - if errConnector != nil || errHTTPClient != nil { - s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient)) - } else { - s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService)) - } - } + clientName := authn.ClientWithPrefix(name) + s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthTokenService, socialService)) } // FIXME (jguer): move to User package diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index e0a4bc676dc..0c6d7099e5a 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -8,7 +8,6 @@ import ( "encoding/hex" "errors" "fmt" - "net/http" "net/url" "strings" @@ -40,6 +39,9 @@ const ( ) var ( + errOAuthClientDisabled = errutil.BadRequest("auth.oauth.disabled", errutil.WithPublicMessage("OAuth client is disabled")) + errOAuthInternal = errutil.Internal("auth.oauth.internal", errutil.WithPublicMessage("An internal error occurred in the OAuth client")) + errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred")) errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie")) @@ -62,24 +64,25 @@ var _ authn.LogoutClient = new(OAuth) var _ authn.RedirectClient = new(OAuth) func ProvideOAuth( - name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo, - connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService, + name string, cfg *setting.Cfg, oauthService oauthtoken.OAuthTokenService, + socialService social.Service, ) *OAuth { + providerName := strings.TrimPrefix(name, "auth.client.") return &OAuth{ - name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")), - log.New(name), cfg, oauthCfg, connector, httpClient, oauthService, + name, fmt.Sprintf("oauth_%s", providerName), providerName, + log.New(name), cfg, oauthService, socialService, } } type OAuth struct { name string moduleName string + providerName string log log.Logger cfg *setting.Cfg - oauthCfg *social.OAuthInfo - connector social.SocialConnector - httpClient *http.Client - oauthService oauthtoken.OAuthTokenService + + oauthService oauthtoken.OAuthTokenService + socialService social.Service } func (c *OAuth) Name() string { @@ -88,6 +91,12 @@ func (c *OAuth) Name() string { func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) { r.SetMeta(authn.MetaKeyAuthModule, c.moduleName) + + oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName) + if !oauthCfg.Enabled { + return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName) + } + // get hashed state stored in cookie stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName) if err != nil { @@ -99,7 +108,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden } // get state returned by the idp and hash it - stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret) + stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, oauthCfg.ClientSecret) // compare the state returned by idp against the one we stored in cookie if stateQuery != stateCookie.Value { return nil, errOAuthInvalidState.Errorf("provided state did not match stored state") @@ -107,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden var opts []oauth2.AuthCodeOption // if pkce is enabled for client validate we have the cookie and set it as url param - if c.oauthCfg.UsePKCE { + if oauthCfg.UsePKCE { pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName) if err != nil { return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err) @@ -115,15 +124,21 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden opts = append(opts, oauth2.VerifierOption(pkceCookie.Value)) } - clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) + connector, errConnector := c.socialService.GetConnector(c.providerName) + httpClient, errHTTPClient := c.socialService.GetOAuthHttpClient(c.providerName) + if errConnector != nil || errHTTPClient != nil { + return nil, errOAuthInternal.Errorf("failed to get %s oauth client: %w", c.name, errors.Join(errConnector, errHTTPClient)) + } + + clientCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) // exchange auth code to a valid token - token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...) + token, err := connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...) if err != nil { return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err) } token.TokenType = "Bearer" - userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token) + userInfo, err := connector.UserInfo(ctx, connector.Client(clientCtx, token), token) if err != nil { var sErr *connectors.SocialError if errors.As(err, &sErr) { @@ -136,7 +151,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided") } - if !c.connector.IsEmailAllowed(userInfo.Email) { + if !connector.IsEmailAllowed(userInfo.Email) { return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed") } @@ -167,7 +182,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden SyncTeams: true, FetchSyncedUser: true, SyncPermissions: true, - AllowSignUp: c.connector.IsSignupAllowed(), + AllowSignUp: connector.IsSignupAllowed(), // skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed SyncOrgRoles: len(orgRoles) > 0, LookUpParams: lookupParams, @@ -178,12 +193,17 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) { var opts []oauth2.AuthCodeOption - if c.oauthCfg.HostedDomain != "" { - opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain)) + oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName) + if !oauthCfg.Enabled { + return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName) + } + + if oauthCfg.HostedDomain != "" { + opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, oauthCfg.HostedDomain)) } var plainPKCE string - if c.oauthCfg.UsePKCE { + if oauthCfg.UsePKCE { verifier, err := genPKCECodeVerifier() if err != nil { return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err) @@ -193,13 +213,18 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir opts = append(opts, oauth2.S256ChallengeOption(plainPKCE)) } - state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret) + state, hashedSate, err := genOAuthState(c.cfg.SecretKey, oauthCfg.ClientSecret) if err != nil { return nil, errOAuthGenState.Errorf("failed to generate state: %w", err) } + connector, err := c.socialService.GetConnector(c.providerName) + if err != nil { + return nil, errOAuthInternal.Errorf("failed to get %s oauth connector: %w", c.name, err) + } + return &authn.Redirect{ - URL: c.connector.AuthCodeURL(state, opts...), + URL: connector.AuthCodeURL(state, opts...), Extra: map[string]string{ authn.KeyOAuthState: hashedSate, authn.KeyOAuthPKCE: plainPKCE, @@ -215,19 +240,25 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err) } - redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg) - if redirctURL == "" { + oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName) + if !oauthCfg.Enabled { + c.log.FromContext(ctx).Debug("OAuth client is disabled") + return nil, false + } + + redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg) + if redirectURL == "" { c.log.FromContext(ctx).Debug("No signout redirect url configured") return nil, false } - if isOICDLogout(redirctURL) && token != nil && token.Valid() { + if isOICDLogout(redirectURL) && token != nil && token.Valid() { if idToken, ok := token.Extra("id_token").(string); ok { - redirctURL = withIDTokenHint(redirctURL, idToken) + redirectURL = withIDTokenHint(redirectURL, idToken) } } - return &authn.Redirect{URL: redirctURL}, true + return &authn.Redirect{URL: redirectURL}, true } // genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string. diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index b9d5cf57e0c..04f0b3d38ce 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" @@ -46,17 +47,23 @@ func TestOAuth_Authenticate(t *testing.T) { { desc: "should return error when missing state cookie", req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, - oauthCfg: &social.OAuthInfo{}, + oauthCfg: &social.OAuthInfo{Enabled: true}, expectedErr: errOAuthMissingState, }, { desc: "should return error when state cookie is present but don't have a value", req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, - oauthCfg: &social.OAuthInfo{}, + oauthCfg: &social.OAuthInfo{Enabled: true}, addStateCookie: true, stateCookieValue: "", expectedErr: errOAuthMissingState, }, + { + desc: "should return error when the client is not enabled", + req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, + oauthCfg: &social.OAuthInfo{Enabled: false}, + expectedErr: errOAuthClientDisabled, + }, { desc: "should return error when state from ipd does not match stored state", req: &authn.Request{HTTPRequest: &http.Request{ @@ -64,7 +71,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-other-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, stateCookieValue: "some-state", expectedErr: errOAuthInvalidState, @@ -76,7 +83,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, stateCookieValue: "some-state", expectedErr: errOAuthMissingPKCE, @@ -88,7 +95,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, stateCookieValue: "some-state", addPKCECookie: true, @@ -103,7 +110,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, stateCookieValue: "some-state", addPKCECookie: true, @@ -119,7 +126,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, stateCookieValue: "some-state", addPKCECookie: true, @@ -157,7 +164,7 @@ func TestOAuth_Authenticate(t *testing.T) { URL: mustParseURL("http://grafana.com/?state=some-state"), }, }, - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, allowInsecureTakeover: true, addStateCookie: true, stateCookieValue: "some-state", @@ -211,12 +218,18 @@ func TestOAuth_Authenticate(t *testing.T) { tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue}) } - c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{ - ExpectedUserInfo: tt.userInfo, - ExpectedToken: &oauth2.Token{}, - ExpectedIsSignupAllowed: true, - ExpectedIsEmailAllowed: tt.isEmailAllowed, - }, nil, nil) + fakeSocialSvc := &socialtest.FakeSocialService{ + ExpectedAuthInfoProvider: tt.oauthCfg, + ExpectedConnector: fakeConnector{ + ExpectedUserInfo: tt.userInfo, + ExpectedToken: &oauth2.Token{}, + ExpectedIsSignupAllowed: true, + ExpectedIsEmailAllowed: tt.isEmailAllowed, + }, + } + + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc) + identity, err := c.Authenticate(context.Background(), tt.req) assert.ErrorIs(t, err, tt.expectedErr) @@ -256,21 +269,27 @@ func TestOAuth_RedirectURL(t *testing.T) { tests := []testCase{ { desc: "should generate redirect url and state", - oauthCfg: &social.OAuthInfo{}, + oauthCfg: &social.OAuthInfo{Enabled: true}, authCodeUrlCalled: true, }, { desc: "should generate redirect url with hosted domain option if configured", - oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"}, + oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true}, numCallOptions: 1, authCodeUrlCalled: true, }, { desc: "should generate redirect url with pkce if configured", - oauthCfg: &social.OAuthInfo{UsePKCE: true}, + oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, numCallOptions: 1, authCodeUrlCalled: true, }, + { + desc: "should return error if the client is not enabled", + oauthCfg: &social.OAuthInfo{Enabled: false}, + authCodeUrlCalled: false, + expectedErr: errOAuthClientDisabled, + }, } for _, tt := range tests { @@ -279,13 +298,18 @@ func TestOAuth_RedirectURL(t *testing.T) { authCodeUrlCalled = false ) - c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{ - AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string { - authCodeUrlCalled = true - require.Len(t, opts, tt.numCallOptions) - return "" + fakeSocialSvc := &socialtest.FakeSocialService{ + ExpectedAuthInfoProvider: tt.oauthCfg, + ExpectedConnector: mockConnector{ + AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string { + authCodeUrlCalled = true + require.Len(t, opts, tt.numCallOptions) + return "" + }, }, - }, nil, nil) + } + + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), nil, fakeSocialSvc) redirect, err := c.RedirectURL(context.Background(), nil) assert.ErrorIs(t, err, tt.expectedErr) @@ -321,12 +345,17 @@ func TestOAuth_Logout(t *testing.T) { cfg: &setting.Cfg{}, oauthCfg: &social.OAuthInfo{}, }, + { + desc: "should not return redirect url when client is not enabled", + cfg: &setting.Cfg{}, + oauthCfg: &social.OAuthInfo{Enabled: false}, + }, { desc: "should return redirect url for globably configured redirect url", cfg: &setting.Cfg{ SignoutRedirectUrl: "http://idp.com/logout", }, - oauthCfg: &social.OAuthInfo{}, + oauthCfg: &social.OAuthInfo{Enabled: true}, expectedURL: "http://idp.com/logout", expectedOK: true, }, @@ -334,6 +363,7 @@ func TestOAuth_Logout(t *testing.T) { desc: "should return redirect url for client configured redirect url", cfg: &setting.Cfg{}, oauthCfg: &social.OAuthInfo{ + Enabled: true, SignoutRedirectUrl: "http://idp.com/logout", }, expectedURL: "http://idp.com/logout", @@ -345,6 +375,7 @@ func TestOAuth_Logout(t *testing.T) { SignoutRedirectUrl: "http://idp.com/logout", }, oauthCfg: &social.OAuthInfo{ + Enabled: true, SignoutRedirectUrl: "http://idp-2.com/logout", }, expectedURL: "http://idp-2.com/logout", @@ -354,6 +385,7 @@ func TestOAuth_Logout(t *testing.T) { desc: "should add id token hint if oicd logout is configured and token is valid", cfg: &setting.Cfg{}, oauthCfg: &social.OAuthInfo{ + Enabled: true, SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin", }, expectedURL: "http://idp.com/logout", @@ -387,7 +419,10 @@ func TestOAuth_Logout(t *testing.T) { }, } - c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService) + fakeSocialSvc := &socialtest.FakeSocialService{ + ExpectedAuthInfoProvider: tt.oauthCfg, + } + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc) redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{}) diff --git a/pkg/services/ssosettings/api/api.go b/pkg/services/ssosettings/api/api.go index e025c2786ea..d2dc4ba4ac1 100644 --- a/pkg/services/ssosettings/api/api.go +++ b/pkg/services/ssosettings/api/api.go @@ -183,7 +183,7 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp settings.Provider = key - err := api.SSOSettingsService.Upsert(c.Req.Context(), settings) + err := api.SSOSettingsService.Upsert(c.Req.Context(), &settings) if err != nil { return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err) } diff --git a/pkg/services/ssosettings/api/api_test.go b/pkg/services/ssosettings/api/api_test.go index 96c94f0e5b4..1da658af811 100644 --- a/pkg/services/ssosettings/api/api_test.go +++ b/pkg/services/ssosettings/api/api_test.go @@ -134,7 +134,7 @@ func TestSSOSettingsAPI_Update(t *testing.T) { service := ssosettingstests.NewMockService(t) if tt.expectedServiceCall { - service.On("Upsert", mock.Anything, settings).Return(tt.expectedError).Once() + service.On("Upsert", mock.Anything, &settings).Return(tt.expectedError).Once() } server := setupTests(t, service) diff --git a/pkg/services/ssosettings/database/database.go b/pkg/services/ssosettings/database/database.go index ea3782a8bbf..320c248a07d 100644 --- a/pkg/services/ssosettings/database/database.go +++ b/pkg/services/ssosettings/database/database.go @@ -84,7 +84,7 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err return result, nil } -func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error { +func (s *SSOSettingsStore) Upsert(ctx context.Context, settings *models.SSOSettings) error { if settings.Provider == "" { return ssosettings.ErrNotFound } @@ -110,13 +110,10 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin } _, err = sess.UseBool(isDeletedColumn).Update(updated, existing) } else { - _, err = sess.Insert(&models.SSOSettings{ - ID: uuid.New().String(), - Provider: settings.Provider, - Settings: settings.Settings, - Created: now, - Updated: now, - }) + settings.ID = uuid.New().String() + settings.Created = now + settings.Updated = now + _, err = sess.Insert(settings) } return err diff --git a/pkg/services/ssosettings/database/database_test.go b/pkg/services/ssosettings/database/database_test.go index bb51ec46bbc..df62c3f4242 100644 --- a/pkg/services/ssosettings/database/database_test.go +++ b/pkg/services/ssosettings/database/database_test.go @@ -105,7 +105,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { }, } - err := ssoSettingsStore.Upsert(context.Background(), settings) + err := ssoSettingsStore.Upsert(context.Background(), &settings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false) @@ -143,7 +143,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { "client_secret": "this-is-a-new-secret", }, } - err = ssoSettingsStore.Upsert(context.Background(), newSettings) + err = ssoSettingsStore.Upsert(context.Background(), &newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, provider, false) @@ -181,7 +181,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { }, } - err = ssoSettingsStore.Upsert(context.Background(), newSettings) + err = ssoSettingsStore.Upsert(context.Background(), &newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, provider, false) @@ -217,7 +217,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { "client_secret": "this-is-my-new-secret", }, } - err = ssoSettingsStore.Upsert(context.Background(), newSettings) + err = ssoSettingsStore.Upsert(context.Background(), &newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false) @@ -254,7 +254,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { }, } - err = ssoSettingsStore.Upsert(context.Background(), settings) + err = ssoSettingsStore.Upsert(context.Background(), &settings) require.Error(t, err) require.ErrorIs(t, err, ssosettings.ErrNotFound) }) diff --git a/pkg/services/ssosettings/errors.go b/pkg/services/ssosettings/errors.go index d01aa046470..83563ff28de 100644 --- a/pkg/services/ssosettings/errors.go +++ b/pkg/services/ssosettings/errors.go @@ -9,7 +9,7 @@ import ( var ( ErrNotFound = errors.New("not found") - ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("provider is invalid")) - ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("settings field is invalid")) - ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("settings.clientId cannot be empty")) + ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("Provider is invalid")) + ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("Settings field is invalid")) + ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("ClientId cannot be empty")) ) diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go index 3a144fc7e5a..f31d30b4fde 100644 --- a/pkg/services/ssosettings/ssosettings.go +++ b/pkg/services/ssosettings/ssosettings.go @@ -28,7 +28,7 @@ type Service interface { // GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error) // Upsert creates or updates the SSO settings for a given provider - Upsert(ctx context.Context, settings models.SSOSettings) error + Upsert(ctx context.Context, settings *models.SSOSettings) error // Delete deletes the SSO settings for a given provider (soft delete) Delete(ctx context.Context, provider string) error // Patch updates the specified SSO settings (key-value pairs) for a given provider @@ -62,6 +62,6 @@ type FallbackStrategy interface { type Store interface { Get(ctx context.Context, provider string) (*models.SSOSettings, error) List(ctx context.Context) ([]*models.SSOSettings, error) - Upsert(ctx context.Context, settings models.SSOSettings) error + Upsert(ctx context.Context, settings *models.SSOSettings) error Delete(ctx context.Context, provider string) error } diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go index 0a02a781757..6ef210ff71b 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -147,7 +147,7 @@ func (s *SSOSettingsService) ListWithRedactedSecrets(ctx context.Context) ([]*mo return storeSettings, nil } -func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error { +func (s *SSOSettingsService) Upsert(ctx context.Context, settings *models.SSOSettings) error { if !isProviderConfigurable(settings.Provider) { return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider) } @@ -157,7 +157,7 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider) } - err := social.Validate(ctx, settings) + err := social.Validate(ctx, *settings) if err != nil { return err } @@ -167,6 +167,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett return err } + secrets := collectSecrets(settings, storedSettings) + settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings) if err != nil { return err @@ -178,7 +180,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett } go func() { - err = social.Reload(context.Background(), settings) + settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets) + err = social.Reload(context.Background(), *settings) if err != nil { s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err) } @@ -249,7 +252,7 @@ func (s *SSOSettingsService) getFallbackStrategyFor(provider string) (ssosetting func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) { result := make(map[string]any) for k, v := range settings { - if isSecret(k) { + if isSecret(k) && v != "" { strValue, ok := v.(string) if !ok { return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v) @@ -315,20 +318,24 @@ func (s *SSOSettingsService) doReload(ctx context.Context) { // mergeSSOSettings merges the settings from the database with the system settings // Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings) -// and the rest of the settings are loaded from the system settings +// and the rest of the settings have to be loaded from the system settings func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings { if dbSettings == nil { s.logger.Debug("No SSO Settings found in the database, using system settings") return systemSettings } - s.logger.Debug("Merging SSO Settings", "dbSettings", dbSettings.Settings, "systemSettings", systemSettings.Settings) + s.logger.Debug("Merging SSO Settings", "dbSettings", removeSecrets(dbSettings.Settings), "systemSettings", removeSecrets(systemSettings.Settings)) - finalSettings := mergeSettings(dbSettings.Settings, systemSettings.Settings) + result := &models.SSOSettings{ + Provider: dbSettings.Provider, + Source: dbSettings.Source, + Settings: mergeSettings(dbSettings.Settings, systemSettings.Settings), + Created: dbSettings.Created, + Updated: dbSettings.Updated, + } - dbSettings.Settings = finalSettings - - return dbSettings + return result } func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) { @@ -358,6 +365,22 @@ func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[st return settings, nil } +// removeSecrets removes all the secrets from the map and replaces them with a redacted password +// and returns a new map +func removeSecrets(settings map[string]any) map[string]any { + result := make(map[string]any) + for k, v := range settings { + if isSecret(k) { + result[k] = setting.RedactedPassword + continue + } + result[k] = v + } + return result +} + +// mergeSettings merges two maps in a way that the values from the first map are preserved +// and the values from the second map are added only if they don't exist in the first map func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any { settings := make(map[string]any) @@ -374,6 +397,32 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any return settings } +// collectSecrets collects all the secrets from the request and the currently stored settings +// and returns a new map +func collectSecrets(settings *models.SSOSettings, storedSettings *models.SSOSettings) map[string]any { + secrets := map[string]any{} + for k, v := range settings.Settings { + if isSecret(k) { + if isNewSecretValue(v.(string)) { + secrets[k] = v.(string) // use the new value + continue + } + secrets[k] = storedSettings.Settings[k] // keep the currently stored value + } + } + return secrets +} + +func overrideMaps(maps ...map[string]any) map[string]any { + result := make(map[string]any) + for _, m := range maps { + for k, v := range m { + result[k] = v + } + } + return result +} + func isSecret(fieldName string) bool { secretFieldPatterns := []string{"secret"} diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go index ec95ffae1cb..a0c1aca5778 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service_test.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -5,7 +5,10 @@ import ( "encoding/base64" "errors" "fmt" + "maps" + "sync" "testing" + "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -772,16 +775,50 @@ func TestSSOSettingsService_Upsert(t *testing.T) { }, IsDeleted: false, } + var wg sync.WaitGroup + wg.Add(1) reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) - reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe() + reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool { + wg.Done() + return settings.Provider == provider && + settings.ID == "someid" && + maps.Equal(settings.Settings, map[string]any{ + "client_id": "client-id", + "client_secret": "client-secret", + "enabled": true, + }) + })).Return(nil).Once() env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() + env.secrets.On("Decrypt", mock.Anything, []byte("encrypted-current-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once() - err := env.service.Upsert(context.Background(), settings) + env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error { + currentTime := time.Now() + settings.ID = "someid" + settings.Created = currentTime + settings.Updated = currentTime + + env.store.ActualSSOSettings = *settings + return nil + } + + env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) { + return &models.SSOSettings{ + ID: "someid", + Provider: provider, + Settings: map[string]any{ + "client_secret": base64.RawStdEncoding.EncodeToString([]byte("encrypted-current-client-secret")), + }, + }, nil + } + err := env.service.Upsert(context.Background(), &settings) require.NoError(t, err) + // Wait for the goroutine first to assert the Reload call + wg.Wait() + settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) require.EqualValues(t, settings, env.store.ActualSSOSettings) }) @@ -790,7 +827,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) { env := setupTestEnv(t) provider := social.GrafanaComProviderName - settings := models.SSOSettings{ + settings := &models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", @@ -811,7 +848,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName - settings := models.SSOSettings{ + settings := &models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", @@ -847,14 +884,14 @@ func TestSSOSettingsService_Upsert(t *testing.T) { reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed")) env.reloadables[provider] = reloadable - err := env.service.Upsert(context.Background(), settings) + err := env.service.Upsert(context.Background(), &settings) require.Error(t, err) }) t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) { env := setupTestEnv(t) - settings := models.SSOSettings{ + settings := &models.SSOSettings{ Provider: social.AzureADProviderName, Settings: map[string]any{ "client_id": "client-id", @@ -889,7 +926,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) { env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once() - err := env.service.Upsert(context.Background(), settings) + err := env.service.Upsert(context.Background(), &settings) require.Error(t, err) }) @@ -921,7 +958,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) { env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once() - err := env.service.Upsert(context.Background(), settings) + err := env.service.Upsert(context.Background(), &settings) require.NoError(t, err) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret")) @@ -950,11 +987,11 @@ func TestSSOSettingsService_Upsert(t *testing.T) { return &models.SSOSettings{}, nil } - env.store.UpsertFn = func(ctx context.Context, settings models.SSOSettings) error { + env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error { return errors.New("failed to upsert settings") } - err := env.service.Upsert(context.Background(), settings) + err := env.service.Upsert(context.Background(), &settings) require.Error(t, err) }) @@ -978,7 +1015,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) { env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() - err := env.service.Upsert(context.Background(), settings) + err := env.service.Upsert(context.Background(), &settings) require.NoError(t, err) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) @@ -1098,6 +1135,22 @@ func TestSSOSettingsService_decryptSecrets(t *testing.T) { "other_secret": "decrypted-other-secret", }, }, + { + name: "should not decrypt when a secret is empty", + setup: func(env testEnv) { + env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once() + }, + settings: map[string]any{ + "enabled": true, + "client_secret": "", + "other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")), + }, + want: map[string]any{ + "enabled": true, + "client_secret": "", + "other_secret": "decrypted-other-secret", + }, + }, { name: "should return an error if data is not a string", settings: map[string]any{ diff --git a/pkg/services/ssosettings/ssosettingstests/reloadable_mock.go b/pkg/services/ssosettings/ssosettingstests/reloadable_mock.go index e4c34dad862..a6c0526cfa4 100644 --- a/pkg/services/ssosettings/ssosettingstests/reloadable_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/reloadable_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package ssosettingstests diff --git a/pkg/services/ssosettings/ssosettingstests/service_mock.go b/pkg/services/ssosettings/ssosettingstests/service_mock.go index d950697bb39..42ef1091fb9 100644 --- a/pkg/services/ssosettings/ssosettingstests/service_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/service_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package ssosettingstests @@ -183,7 +183,7 @@ func (_m *MockService) Reload(ctx context.Context, provider string) { } // Upsert provides a mock function with given fields: ctx, settings -func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error { +func (_m *MockService) Upsert(ctx context.Context, settings *models.SSOSettings) error { ret := _m.Called(ctx, settings) if len(ret) == 0 { @@ -191,7 +191,7 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok { r0 = rf(ctx, settings) } else { r0 = ret.Error(0) diff --git a/pkg/services/ssosettings/ssosettingstests/store_fake.go b/pkg/services/ssosettings/ssosettingstests/store_fake.go index d9e3d2539dc..016263ce8b5 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_fake.go +++ b/pkg/services/ssosettings/ssosettingstests/store_fake.go @@ -17,7 +17,7 @@ type FakeStore struct { ActualSSOSettings models.SSOSettings GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error) - UpsertFn func(ctx context.Context, settings models.SSOSettings) error + UpsertFn func(ctx context.Context, settings *models.SSOSettings) error } func NewFakeStore() *FakeStore { @@ -35,12 +35,12 @@ func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) { return f.ExpectedSSOSettings, f.ExpectedError } -func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error { +func (f *FakeStore) Upsert(ctx context.Context, settings *models.SSOSettings) error { if f.UpsertFn != nil { return f.UpsertFn(ctx, settings) } - f.ActualSSOSettings = settings + f.ActualSSOSettings = *settings return f.ExpectedError } diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go index 0db6238da07..009660cd3e6 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.39.2. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package ssosettingstests @@ -93,7 +93,7 @@ func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) { } // Upsert provides a mock function with given fields: ctx, settings -func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error { +func (_m *MockStore) Upsert(ctx context.Context, settings *models.SSOSettings) error { ret := _m.Called(ctx, settings) if len(ret) == 0 { @@ -101,7 +101,7 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok { r0 = rf(ctx, settings) } else { r0 = ret.Error(0) diff --git a/pkg/services/ssosettings/strategies/oauth_strategy.go b/pkg/services/ssosettings/strategies/oauth_strategy.go index 2e036190686..3dfe79a1bc3 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy.go @@ -2,6 +2,7 @@ package strategies import ( "context" + "maps" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/ssosettings" @@ -31,7 +32,10 @@ func (s *OAuthStrategy) IsMatch(provider string) bool { } func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) { - return s.settingsByProvider[provider], nil + providerConfig := s.settingsByProvider[provider] + result := make(map[string]any, len(providerConfig)) + maps.Copy(result, providerConfig) + return result, nil } func (s *OAuthStrategy) loadAllSettings() {