AuthN: Support reloading SSO config after the sso settings have changed (#80734)

* Add AuthNSvc reload handling

* Working, need to add test

* Remove commented out code

* Add Reload implementation to connectors

* Align and add tests, refactor

* Add more tests, linting

* Add extra checks + tests to oauth client

* Clean up based on reviews

* Move config instantiation into newSocialBase

* Use specific error
This commit is contained in:
Misi
2024-01-22 14:54:48 +01:00
committed by GitHub
parent 1f4a520b9d
commit 20bb0a3ab1
31 changed files with 889 additions and 217 deletions

View File

@@ -73,16 +73,15 @@ type keySetJWKS struct {
} }
func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD { func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD {
config := createOAuthConfig(info, cfg, social.AzureADProviderName)
provider := &SocialAzureAD{ provider := &SocialAzureAD{
SocialBase: newSocialBase(social.AzureADProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.AzureADProviderName, info, features, cfg),
cache: cache, cache: cache,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false), forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false),
} }
if info.UseRefreshToken { if info.UseRefreshToken {
appendUniqueScope(config, social.OfflineAccessScope) appendUniqueScope(provider.Config, social.OfflineAccessScope)
} }
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -164,6 +163,27 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
}, nil }, nil
} }
func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.AzureADProviderName, newInfo, s.features, s.cfg)
if newInfo.UseRefreshToken {
appendUniqueScope(s.Config, social.OfflineAccessScope)
}
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
s.forceUseGraphAPI = MustBool(newInfo.Extra[forceUseGraphAPIKey], false)
return nil
}
func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error {
info, err := CreateOAuthInfoFromKeyValues(settings.Settings) info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil { if err != nil {

View File

@@ -1048,11 +1048,12 @@ func TestSocialAzureAD_Validate(t *testing.T) {
func TestSocialAzureAD_Reload(t *testing.T) { func TestSocialAzureAD_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -1073,6 +1074,14 @@ func TestSocialAzureAD_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/azuread",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -1092,6 +1101,11 @@ func TestSocialAzureAD_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/azuread",
},
}, },
} }
@@ -1102,10 +1116,65 @@ func TestSocialAzureAD_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
expectedForceUseGraphApi bool
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
"force_use_graph_api": "true",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
expectedForceUseGraphApi: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewAzureADProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedForceUseGraphApi, s.forceUseGraphAPI)
}) })
} }
} }

View File

@@ -46,9 +46,8 @@ type SocialGenericOAuth struct {
} }
func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth { func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth {
config := createOAuthConfig(info, cfg, social.GenericOAuthProviderName)
provider := &SocialGenericOAuth{ provider := &SocialGenericOAuth{
SocialBase: newSocialBase(social.GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.GenericOAuthProviderName, info, features, cfg),
teamsUrl: info.TeamsUrl, teamsUrl: info.TeamsUrl,
emailAttributeName: info.EmailAttributeName, emailAttributeName: info.EmailAttributeName,
emailAttributePath: info.EmailAttributePath, emailAttributePath: info.EmailAttributePath,
@@ -84,6 +83,31 @@ func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SS
return nil return nil
} }
func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GenericOAuthProviderName, newInfo, s.features, s.cfg)
s.teamsUrl = newInfo.TeamsUrl
s.emailAttributeName = newInfo.EmailAttributeName
s.emailAttributePath = newInfo.EmailAttributePath
s.nameAttributePath = newInfo.Extra[nameAttributePathKey]
s.groupsAttributePath = newInfo.GroupsAttributePath
s.loginAttributePath = newInfo.Extra[loginAttributePathKey]
s.idTokenAttributeName = newInfo.Extra[idTokenAttributeNameKey]
s.teamIdsAttributePath = newInfo.TeamIdsAttributePath
s.teamIds = util.SplitString(newInfo.Extra[teamIdsKey])
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil
}
// TODOD: remove this in the next PR and use the isGroupMember from social.go // TODOD: remove this in the next PR and use the isGroupMember from social.go
func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool { func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
info := s.GetOAuthInfo() info := s.GetOAuthInfo()

View File

@@ -976,11 +976,12 @@ func TestSocialGenericOAuth_Validate(t *testing.T) {
func TestSocialGenericOAuth_Reload(t *testing.T) { func TestSocialGenericOAuth_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -1001,6 +1002,14 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/generic_oauth",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -1020,6 +1029,11 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/generic_oauth",
},
}, },
} }
@@ -1030,10 +1044,116 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestGenericOAuth_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedTeamsUrl string
expectedEmailAttributeName string
expectedEmailAttributePath string
expectedNameAttributePath string
expectedGroupsAttributePath string
expectedLoginAttributePath string
expectedIdTokenAttributeName string
expectedTeamIdsAttributePath string
expectedTeamIds []string
expectedAllowedOrganizations []string
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
TeamsUrl: "https://host/users",
EmailAttributePath: "email-attr-path",
EmailAttributeName: "email-attr-name",
GroupsAttributePath: "groups-attr-path",
TeamIdsAttributePath: "team-ids-attr-path",
Extra: map[string]string{
teamIdsKey: "team1",
allowedOrganizationsKey: "org1",
loginAttributePathKey: "login-attr-path",
idTokenAttributeNameKey: "id-token-attr-name",
nameAttributePathKey: "name-attr-path",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"teams_url": "https://host/v2/users",
"email_attribute_path": "new-email-attr-path",
"email_attribute_name": "new-email-attr-name",
"groups_attribute_path": "new-group-attr-path",
"team_ids_attribute_path": "new-team-ids-attr-path",
teamIdsKey: "team1,team2",
allowedOrganizationsKey: "org1,org2",
loginAttributePathKey: "new-login-attr-path",
idTokenAttributeNameKey: "new-id-token-attr-name",
nameAttributePathKey: "new-name-attr-path",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
TeamsUrl: "https://host/v2/users",
EmailAttributePath: "new-email-attr-path",
EmailAttributeName: "new-email-attr-name",
GroupsAttributePath: "new-group-attr-path",
TeamIdsAttributePath: "new-team-ids-attr-path",
Extra: map[string]string{
teamIdsKey: "team1,team2",
allowedOrganizationsKey: "org1,org2",
loginAttributePathKey: "new-login-attr-path",
idTokenAttributeNameKey: "new-id-token-attr-name",
nameAttributePathKey: "new-name-attr-path",
},
},
expectedTeamsUrl: "https://host/v2/users",
expectedEmailAttributeName: "new-email-attr-name",
expectedEmailAttributePath: "new-email-attr-path",
expectedGroupsAttributePath: "new-group-attr-path",
expectedTeamIdsAttributePath: "new-team-ids-attr-path",
expectedTeamIds: []string{"team1", "team2"},
expectedAllowedOrganizations: []string{"org1", "org2"},
expectedLoginAttributePath: "new-login-attr-path",
expectedIdTokenAttributeName: "new-id-token-attr-name",
expectedNameAttributePath: "new-name-attr-path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewGenericOAuthProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedTeamsUrl, s.teamsUrl)
require.EqualValues(t, tc.expectedEmailAttributeName, s.emailAttributeName)
require.EqualValues(t, tc.expectedEmailAttributePath, s.emailAttributePath)
require.EqualValues(t, tc.expectedGroupsAttributePath, s.groupsAttributePath)
require.EqualValues(t, tc.expectedTeamIdsAttributePath, s.teamIdsAttributePath)
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedLoginAttributePath, s.loginAttributePath)
require.EqualValues(t, tc.expectedIdTokenAttributeName, s.idTokenAttributeName)
require.EqualValues(t, tc.expectedNameAttributePath, s.nameAttributePath)
}) })
} }
} }

View File

@@ -57,9 +57,8 @@ func NewGitHubProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings sso
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey]) teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted) teamIds := mustInts(teamIdsSplitted)
config := createOAuthConfig(info, cfg, social.GitHubProviderName)
provider := &SocialGithub{ provider := &SocialGithub{
SocialBase: newSocialBase(social.GitHubProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.GitHubProviderName, info, features, cfg),
teamIds: teamIds, teamIds: teamIds,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
} }
@@ -86,7 +85,37 @@ func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSetti
return err return err
} }
// add specific validation rules for Github teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted)
if len(teamIdsSplitted) != len(teamIds) {
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
return ssosettings.ErrInvalidSettings.Errorf("Failed to parse team ids. Team ids must be a list of numbers.")
}
return nil
}
func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
teamIdsSplitted := util.SplitString(newInfo.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted)
if len(teamIdsSplitted) != len(teamIds) {
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GitHubProviderName, newInfo, s.features, s.cfg)
s.teamIds = teamIds
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil return nil
} }

View File

@@ -402,11 +402,12 @@ func TestSocialGitHub_Validate(t *testing.T) {
func TestSocialGitHub_Reload(t *testing.T) { func TestSocialGitHub_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -427,6 +428,14 @@ func TestSocialGitHub_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/github",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -446,6 +455,11 @@ func TestSocialGitHub_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/github",
},
}, },
} }
@@ -456,10 +470,67 @@ func TestSocialGitHub_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestGitHub_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
expectedTeamIds []int
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
"team_ids": "",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
"team_ids": "123,456",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
AuthStyle: "inheader",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
expectedTeamIds: []int{123, 456},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewGitHubProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
}) })
} }
} }

View File

@@ -53,9 +53,8 @@ type userData struct {
} }
func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab { func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab {
config := createOAuthConfig(info, cfg, social.GitlabProviderName)
provider := &SocialGitlab{ provider := &SocialGitlab{
SocialBase: newSocialBase(social.GitlabProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.GitlabProviderName, info, features, cfg),
} }
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -76,7 +75,19 @@ func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSetti
return err return err
} }
// add specific validation rules for Gitlab return nil
}
func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GitlabProviderName, newInfo, s.features, s.cfg)
return nil return nil
} }

View File

@@ -164,7 +164,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
for _, test := range tests { for _, test := range tests {
provider.info.RoleAttributePath = test.RoleAttributePath provider.info.RoleAttributePath = test.RoleAttributePath
provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin
provider.autoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole) provider.cfg.AutoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole)
provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict
provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync
@@ -520,11 +520,12 @@ func TestSocialGitlab_Validate(t *testing.T) {
func TestSocialGitlab_Reload(t *testing.T) { func TestSocialGitlab_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -545,6 +546,14 @@ func TestSocialGitlab_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/gitlab",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -564,6 +573,11 @@ func TestSocialGitlab_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/gitlab",
},
}, },
} }
@@ -574,10 +588,12 @@ func TestSocialGitlab_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
}) })
} }
} }

View File

@@ -39,9 +39,8 @@ type googleUserData struct {
} }
func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle { func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle {
config := createOAuthConfig(info, cfg, social.GoogleProviderName)
provider := &SocialGoogle{ provider := &SocialGoogle{
SocialBase: newSocialBase(social.GoogleProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.GoogleProviderName, info, features, cfg),
} }
if strings.HasPrefix(info.ApiUrl, legacyAPIURL) { if strings.HasPrefix(info.ApiUrl, legacyAPIURL) {
@@ -71,6 +70,24 @@ func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSetti
return nil return nil
} }
func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
if strings.HasPrefix(newInfo.ApiUrl, legacyAPIURL) {
s.log.Warn("Using legacy Google API URL, please update your configuration")
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GoogleProviderName, newInfo, s.features, s.cfg)
return nil
}
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
info := s.GetOAuthInfo() info := s.GetOAuthInfo()
@@ -225,7 +242,7 @@ type googleGroupResp struct {
} }
func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) { func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) {
s.log.Debug("Retrieving groups", "scopes", s.SocialBase.Config.Scopes) s.log.Debug("Retrieving groups", "scopes", s.Config.Scopes)
if !slices.Contains(s.Scopes, googleIAMScope) { if !slices.Contains(s.Scopes, googleIAMScope) {
return nil, nil return nil, nil
} }

View File

@@ -725,11 +725,12 @@ func TestSocialGoogle_Validate(t *testing.T) {
func TestSocialGoogle_Reload(t *testing.T) { func TestSocialGoogle_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -750,6 +751,14 @@ func TestSocialGoogle_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/google",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -769,6 +778,11 @@ func TestSocialGoogle_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/google",
},
}, },
} }
@@ -779,10 +793,12 @@ func TestSocialGoogle_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
}) })
} }
} }

View File

@@ -39,9 +39,8 @@ func NewGrafanaComProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings
info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token" info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token"
info.AuthStyle = "inheader" info.AuthStyle = "inheader"
config := createOAuthConfig(info, cfg, social.GrafanaComProviderName)
provider := &SocialGrafanaCom{ provider := &SocialGrafanaCom{
SocialBase: newSocialBase(social.GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.GrafanaComProviderName, info, features, cfg),
url: cfg.GrafanaComURL, url: cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
} }
@@ -69,6 +68,28 @@ func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOS
return nil return nil
} }
func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
// Override necessary settings
newInfo.AuthUrl = s.cfg.GrafanaComURL + "/oauth2/authorize"
newInfo.TokenUrl = s.cfg.GrafanaComURL + "/api/oauth2/token"
newInfo.AuthStyle = "inheader"
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GrafanaComProviderName, newInfo, s.features, s.cfg)
s.url = s.cfg.GrafanaComURL
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil
}
func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool {
return true return true
} }

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/featuremgmt"
@@ -195,11 +196,12 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
const GrafanaComURL = "http://localhost:3000" const GrafanaComURL = "http://localhost:3000"
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -219,6 +221,19 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
ClientId: "new-client-id", ClientId: "new-client-id",
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
Name: "a-new-name", Name: "a-new-name",
AuthUrl: GrafanaComURL + "/oauth2/authorize",
TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: GrafanaComURL + "/oauth2/authorize",
TokenURL: GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: "/login/grafana_com",
}, },
}, },
{ {
@@ -243,6 +258,16 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
TokenUrl: GrafanaComURL + "/api/oauth2/token", TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader", AuthStyle: "inheader",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: GrafanaComURL + "/oauth2/authorize",
TokenURL: GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: "/login/grafana_com",
},
}, },
} }
@@ -256,10 +281,68 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestSocialGrafanaCom_Reload_ExtraFields(t *testing.T) {
const GrafanaComURL = "http://localhost:3000"
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
}{
{
name: "successfully reloads the allowed organizations when they are set in the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
AuthUrl: GrafanaComURL + "/oauth2/authorize",
TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg := &setting.Cfg{
GrafanaComURL: GrafanaComURL,
}
s := NewGrafanaComProvider(tc.info, cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
}) })
} }
} }

View File

@@ -45,13 +45,12 @@ type OktaClaims struct {
} }
func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta { func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta {
config := createOAuthConfig(info, cfg, social.OktaProviderName)
provider := &SocialOkta{ provider := &SocialOkta{
SocialBase: newSocialBase(social.OktaProviderName, config, info, cfg.AutoAssignOrgRole, features), SocialBase: newSocialBase(social.OktaProviderName, info, features, cfg),
} }
if info.UseRefreshToken { if info.UseRefreshToken {
appendUniqueScope(config, social.OfflineAccessScope) appendUniqueScope(provider.Config, social.OfflineAccessScope)
} }
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -77,6 +76,23 @@ func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSetting
return nil return nil
} }
func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.OktaProviderName, newInfo, s.features, s.cfg)
if newInfo.UseRefreshToken {
appendUniqueScope(s.Config, social.OfflineAccessScope)
}
return nil
}
func (claims *OktaClaims) extractEmail() string { func (claims *OktaClaims) extractEmail() string {
if claims.Email == "" && claims.PreferredUsername != "" { if claims.Email == "" && claims.PreferredUsername != "" {
return claims.PreferredUsername return claims.PreferredUsername

View File

@@ -193,11 +193,12 @@ func TestSocialOkta_Validate(t *testing.T) {
func TestSocialOkta_Reload(t *testing.T) { func TestSocialOkta_Reload(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
info *social.OAuthInfo info *social.OAuthInfo
settings ssoModels.SSOSettings settings ssoModels.SSOSettings
expectError bool expectError bool
expectedInfo *social.OAuthInfo expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{ }{
{ {
name: "SSO provider successfully updated", name: "SSO provider successfully updated",
@@ -218,6 +219,14 @@ func TestSocialOkta_Reload(t *testing.T) {
ClientSecret: "new-client-secret", ClientSecret: "new-client-secret",
AuthUrl: "some-new-url", AuthUrl: "some-new-url",
}, },
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/okta",
},
}, },
{ {
name: "fails if settings contain invalid values", name: "fails if settings contain invalid values",
@@ -237,6 +246,11 @@ func TestSocialOkta_Reload(t *testing.T) {
ClientId: "client-id", ClientId: "client-id",
ClientSecret: "client-secret", ClientSecret: "client-secret",
}, },
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/okta",
},
}, },
} }
@@ -247,10 +261,12 @@ func TestSocialOkta_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings) err := s.Reload(context.Background(), tc.settings)
if tc.expectError { if tc.expectError {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info) require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
}) })
} }
} }

View File

@@ -3,7 +3,6 @@ package connectors
import ( import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -21,32 +20,31 @@ import (
"github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings"
ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting"
) )
type SocialBase struct { type SocialBase struct {
*oauth2.Config *oauth2.Config
info *social.OAuthInfo info *social.OAuthInfo
infoMutex sync.RWMutex cfg *setting.Cfg
log log.Logger reloadMutex sync.RWMutex
autoAssignOrgRole string log log.Logger
features featuremgmt.FeatureToggles features featuremgmt.FeatureToggles
} }
func newSocialBase(name string, func newSocialBase(name string,
config *oauth2.Config,
info *social.OAuthInfo, info *social.OAuthInfo,
autoAssignOrgRole string,
features featuremgmt.FeatureToggles, features featuremgmt.FeatureToggles,
cfg *setting.Cfg,
) *SocialBase { ) *SocialBase {
logger := log.New("oauth." + name) logger := log.New("oauth." + name)
return &SocialBase{ return &SocialBase{
Config: config, Config: createOAuthConfig(info, cfg, name),
info: info, info: info,
log: logger, log: logger,
autoAssignOrgRole: autoAssignOrgRole, features: features,
features: features, cfg: cfg,
} }
} }
@@ -60,7 +58,7 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin)) bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin))
bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup)) bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup))
bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains)) bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains))
bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole)) bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.cfg.AutoAssignOrgRole))
bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath)) bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath))
bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict)) bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict))
bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync)) bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync))
@@ -76,26 +74,12 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
} }
func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo { func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo {
s.infoMutex.RLock() s.reloadMutex.RLock()
defer s.infoMutex.RUnlock() defer s.reloadMutex.RUnlock()
return s.info return s.info
} }
func (s *SocialBase) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return fmt.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.infoMutex.Lock()
defer s.infoMutex.Unlock()
s.info = info
return nil
}
func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) { func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) {
if s.info.RoleAttributePath == "" { if s.info.RoleAttributePath == "" {
if s.info.RoleAttributeStrict { if s.info.RoleAttributeStrict {
@@ -145,9 +129,9 @@ func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType,
// defaultRole returns the default role for the user based on the autoAssignOrgRole setting // defaultRole returns the default role for the user based on the autoAssignOrgRole setting
// if legacy is enabled "" is returned indicating the previous role assignment is used. // if legacy is enabled "" is returned indicating the previous role assignment is used.
func (s *SocialBase) defaultRole() org.RoleType { func (s *SocialBase) defaultRole() org.RoleType {
if s.autoAssignOrgRole != "" { if s.cfg.AutoAssignOrgRole != "" {
s.log.Debug("No role found, returning default.") s.log.Debug("No role found, returning default.")
return org.RoleType(s.autoAssignOrgRole) return org.RoleType(s.cfg.AutoAssignOrgRole)
} }
// should never happen // should never happen

View File

@@ -140,18 +140,8 @@ func ProvideService(
} }
for name := range socialService.GetOAuthProviders() { for name := range socialService.GetOAuthProviders() {
oauthCfg := socialService.GetOAuthInfoProvider(name) clientName := authn.ClientWithPrefix(name)
if oauthCfg != nil && oauthCfg.Enabled { s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthTokenService, socialService))
clientName := authn.ClientWithPrefix(name)
connector, errConnector := socialService.GetConnector(name)
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
if errConnector != nil || errHTTPClient != nil {
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
} else {
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
}
}
} }
// FIXME (jguer): move to User package // FIXME (jguer): move to User package

View File

@@ -8,7 +8,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
@@ -40,6 +39,9 @@ const (
) )
var ( var (
errOAuthClientDisabled = errutil.BadRequest("auth.oauth.disabled", errutil.WithPublicMessage("OAuth client is disabled"))
errOAuthInternal = errutil.Internal("auth.oauth.internal", errutil.WithPublicMessage("An internal error occurred in the OAuth client"))
errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred")) errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie")) errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
@@ -62,24 +64,25 @@ var _ authn.LogoutClient = new(OAuth)
var _ authn.RedirectClient = new(OAuth) var _ authn.RedirectClient = new(OAuth)
func ProvideOAuth( func ProvideOAuth(
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo, name string, cfg *setting.Cfg, oauthService oauthtoken.OAuthTokenService,
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService, socialService social.Service,
) *OAuth { ) *OAuth {
providerName := strings.TrimPrefix(name, "auth.client.")
return &OAuth{ return &OAuth{
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")), name, fmt.Sprintf("oauth_%s", providerName), providerName,
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService, log.New(name), cfg, oauthService, socialService,
} }
} }
type OAuth struct { type OAuth struct {
name string name string
moduleName string moduleName string
providerName string
log log.Logger log log.Logger
cfg *setting.Cfg cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector oauthService oauthtoken.OAuthTokenService
httpClient *http.Client socialService social.Service
oauthService oauthtoken.OAuthTokenService
} }
func (c *OAuth) Name() string { func (c *OAuth) Name() string {
@@ -88,6 +91,12 @@ func (c *OAuth) Name() string {
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) { func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName) r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
// get hashed state stored in cookie // get hashed state stored in cookie
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName) stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
if err != nil { if err != nil {
@@ -99,7 +108,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
} }
// get state returned by the idp and hash it // get state returned by the idp and hash it
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret) stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, oauthCfg.ClientSecret)
// compare the state returned by idp against the one we stored in cookie // compare the state returned by idp against the one we stored in cookie
if stateQuery != stateCookie.Value { if stateQuery != stateCookie.Value {
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state") return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
@@ -107,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
var opts []oauth2.AuthCodeOption var opts []oauth2.AuthCodeOption
// if pkce is enabled for client validate we have the cookie and set it as url param // if pkce is enabled for client validate we have the cookie and set it as url param
if c.oauthCfg.UsePKCE { if oauthCfg.UsePKCE {
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName) pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
if err != nil { if err != nil {
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err) return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
@@ -115,15 +124,21 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
opts = append(opts, oauth2.VerifierOption(pkceCookie.Value)) opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
} }
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) connector, errConnector := c.socialService.GetConnector(c.providerName)
httpClient, errHTTPClient := c.socialService.GetOAuthHttpClient(c.providerName)
if errConnector != nil || errHTTPClient != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth client: %w", c.name, errors.Join(errConnector, errHTTPClient))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// exchange auth code to a valid token // exchange auth code to a valid token
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...) token, err := connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
if err != nil { if err != nil {
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err) return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
} }
token.TokenType = "Bearer" token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token) userInfo, err := connector.UserInfo(ctx, connector.Client(clientCtx, token), token)
if err != nil { if err != nil {
var sErr *connectors.SocialError var sErr *connectors.SocialError
if errors.As(err, &sErr) { if errors.As(err, &sErr) {
@@ -136,7 +151,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided") return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
} }
if !c.connector.IsEmailAllowed(userInfo.Email) { if !connector.IsEmailAllowed(userInfo.Email) {
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed") return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
} }
@@ -167,7 +182,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
SyncTeams: true, SyncTeams: true,
FetchSyncedUser: true, FetchSyncedUser: true,
SyncPermissions: true, SyncPermissions: true,
AllowSignUp: c.connector.IsSignupAllowed(), AllowSignUp: connector.IsSignupAllowed(),
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed // skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
SyncOrgRoles: len(orgRoles) > 0, SyncOrgRoles: len(orgRoles) > 0,
LookUpParams: lookupParams, LookUpParams: lookupParams,
@@ -178,12 +193,17 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) { func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
var opts []oauth2.AuthCodeOption var opts []oauth2.AuthCodeOption
if c.oauthCfg.HostedDomain != "" { oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain)) if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
if oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, oauthCfg.HostedDomain))
} }
var plainPKCE string var plainPKCE string
if c.oauthCfg.UsePKCE { if oauthCfg.UsePKCE {
verifier, err := genPKCECodeVerifier() verifier, err := genPKCECodeVerifier()
if err != nil { if err != nil {
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err) return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
@@ -193,13 +213,18 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
opts = append(opts, oauth2.S256ChallengeOption(plainPKCE)) opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
} }
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret) state, hashedSate, err := genOAuthState(c.cfg.SecretKey, oauthCfg.ClientSecret)
if err != nil { if err != nil {
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err) return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
} }
connector, err := c.socialService.GetConnector(c.providerName)
if err != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth connector: %w", c.name, err)
}
return &authn.Redirect{ return &authn.Redirect{
URL: c.connector.AuthCodeURL(state, opts...), URL: connector.AuthCodeURL(state, opts...),
Extra: map[string]string{ Extra: map[string]string{
authn.KeyOAuthState: hashedSate, authn.KeyOAuthState: hashedSate,
authn.KeyOAuthPKCE: plainPKCE, authn.KeyOAuthPKCE: plainPKCE,
@@ -215,19 +240,25 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err) c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
} }
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg) oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if redirctURL == "" { if !oauthCfg.Enabled {
c.log.FromContext(ctx).Debug("OAuth client is disabled")
return nil, false
}
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
if redirectURL == "" {
c.log.FromContext(ctx).Debug("No signout redirect url configured") c.log.FromContext(ctx).Debug("No signout redirect url configured")
return nil, false return nil, false
} }
if isOICDLogout(redirctURL) && token != nil && token.Valid() { if isOICDLogout(redirectURL) && token != nil && token.Valid() {
if idToken, ok := token.Extra("id_token").(string); ok { if idToken, ok := token.Extra("id_token").(string); ok {
redirctURL = withIDTokenHint(redirctURL, idToken) redirectURL = withIDTokenHint(redirectURL, idToken)
} }
} }
return &authn.Redirect{URL: redirctURL}, true return &authn.Redirect{URL: redirectURL}, true
} }
// genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string. // genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
@@ -46,17 +47,23 @@ func TestOAuth_Authenticate(t *testing.T) {
{ {
desc: "should return error when missing state cookie", desc: "should return error when missing state cookie",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{}, oauthCfg: &social.OAuthInfo{Enabled: true},
expectedErr: errOAuthMissingState, expectedErr: errOAuthMissingState,
}, },
{ {
desc: "should return error when state cookie is present but don't have a value", desc: "should return error when state cookie is present but don't have a value",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{}, oauthCfg: &social.OAuthInfo{Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "", stateCookieValue: "",
expectedErr: errOAuthMissingState, expectedErr: errOAuthMissingState,
}, },
{
desc: "should return error when the client is not enabled",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{Enabled: false},
expectedErr: errOAuthClientDisabled,
},
{ {
desc: "should return error when state from ipd does not match stored state", desc: "should return error when state from ipd does not match stored state",
req: &authn.Request{HTTPRequest: &http.Request{ req: &authn.Request{HTTPRequest: &http.Request{
@@ -64,7 +71,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-other-state"), URL: mustParseURL("http://grafana.com/?state=some-other-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
expectedErr: errOAuthInvalidState, expectedErr: errOAuthInvalidState,
@@ -76,7 +83,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
expectedErr: errOAuthMissingPKCE, expectedErr: errOAuthMissingPKCE,
@@ -88,7 +95,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
addPKCECookie: true, addPKCECookie: true,
@@ -103,7 +110,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
addPKCECookie: true, addPKCECookie: true,
@@ -119,7 +126,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
addPKCECookie: true, addPKCECookie: true,
@@ -157,7 +164,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"), URL: mustParseURL("http://grafana.com/?state=some-state"),
}, },
}, },
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
allowInsecureTakeover: true, allowInsecureTakeover: true,
addStateCookie: true, addStateCookie: true,
stateCookieValue: "some-state", stateCookieValue: "some-state",
@@ -211,12 +218,18 @@ func TestOAuth_Authenticate(t *testing.T) {
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue}) tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
} }
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{ fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedUserInfo: tt.userInfo, ExpectedAuthInfoProvider: tt.oauthCfg,
ExpectedToken: &oauth2.Token{}, ExpectedConnector: fakeConnector{
ExpectedIsSignupAllowed: true, ExpectedUserInfo: tt.userInfo,
ExpectedIsEmailAllowed: tt.isEmailAllowed, ExpectedToken: &oauth2.Token{},
}, nil, nil) ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc)
identity, err := c.Authenticate(context.Background(), tt.req) identity, err := c.Authenticate(context.Background(), tt.req)
assert.ErrorIs(t, err, tt.expectedErr) assert.ErrorIs(t, err, tt.expectedErr)
@@ -256,21 +269,27 @@ func TestOAuth_RedirectURL(t *testing.T) {
tests := []testCase{ tests := []testCase{
{ {
desc: "should generate redirect url and state", desc: "should generate redirect url and state",
oauthCfg: &social.OAuthInfo{}, oauthCfg: &social.OAuthInfo{Enabled: true},
authCodeUrlCalled: true, authCodeUrlCalled: true,
}, },
{ {
desc: "should generate redirect url with hosted domain option if configured", desc: "should generate redirect url with hosted domain option if configured",
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"}, oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true},
numCallOptions: 1, numCallOptions: 1,
authCodeUrlCalled: true, authCodeUrlCalled: true,
}, },
{ {
desc: "should generate redirect url with pkce if configured", desc: "should generate redirect url with pkce if configured",
oauthCfg: &social.OAuthInfo{UsePKCE: true}, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
numCallOptions: 1, numCallOptions: 1,
authCodeUrlCalled: true, authCodeUrlCalled: true,
}, },
{
desc: "should return error if the client is not enabled",
oauthCfg: &social.OAuthInfo{Enabled: false},
authCodeUrlCalled: false,
expectedErr: errOAuthClientDisabled,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -279,13 +298,18 @@ func TestOAuth_RedirectURL(t *testing.T) {
authCodeUrlCalled = false authCodeUrlCalled = false
) )
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{ fakeSocialSvc := &socialtest.FakeSocialService{
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string { ExpectedAuthInfoProvider: tt.oauthCfg,
authCodeUrlCalled = true ExpectedConnector: mockConnector{
require.Len(t, opts, tt.numCallOptions) AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
return "" authCodeUrlCalled = true
require.Len(t, opts, tt.numCallOptions)
return ""
},
}, },
}, nil, nil) }
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), nil, fakeSocialSvc)
redirect, err := c.RedirectURL(context.Background(), nil) redirect, err := c.RedirectURL(context.Background(), nil)
assert.ErrorIs(t, err, tt.expectedErr) assert.ErrorIs(t, err, tt.expectedErr)
@@ -321,12 +345,17 @@ func TestOAuth_Logout(t *testing.T) {
cfg: &setting.Cfg{}, cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{}, oauthCfg: &social.OAuthInfo{},
}, },
{
desc: "should not return redirect url when client is not enabled",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{Enabled: false},
},
{ {
desc: "should return redirect url for globably configured redirect url", desc: "should return redirect url for globably configured redirect url",
cfg: &setting.Cfg{ cfg: &setting.Cfg{
SignoutRedirectUrl: "http://idp.com/logout", SignoutRedirectUrl: "http://idp.com/logout",
}, },
oauthCfg: &social.OAuthInfo{}, oauthCfg: &social.OAuthInfo{Enabled: true},
expectedURL: "http://idp.com/logout", expectedURL: "http://idp.com/logout",
expectedOK: true, expectedOK: true,
}, },
@@ -334,6 +363,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should return redirect url for client configured redirect url", desc: "should return redirect url for client configured redirect url",
cfg: &setting.Cfg{}, cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{ oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout", SignoutRedirectUrl: "http://idp.com/logout",
}, },
expectedURL: "http://idp.com/logout", expectedURL: "http://idp.com/logout",
@@ -345,6 +375,7 @@ func TestOAuth_Logout(t *testing.T) {
SignoutRedirectUrl: "http://idp.com/logout", SignoutRedirectUrl: "http://idp.com/logout",
}, },
oauthCfg: &social.OAuthInfo{ oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp-2.com/logout", SignoutRedirectUrl: "http://idp-2.com/logout",
}, },
expectedURL: "http://idp-2.com/logout", expectedURL: "http://idp-2.com/logout",
@@ -354,6 +385,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should add id token hint if oicd logout is configured and token is valid", desc: "should add id token hint if oicd logout is configured and token is valid",
cfg: &setting.Cfg{}, cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{ oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin", SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
}, },
expectedURL: "http://idp.com/logout", expectedURL: "http://idp.com/logout",
@@ -387,7 +419,10 @@ func TestOAuth_Logout(t *testing.T) {
}, },
} }
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService) fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc)
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{}) redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})

View File

@@ -183,7 +183,7 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
settings.Provider = key settings.Provider = key
err := api.SSOSettingsService.Upsert(c.Req.Context(), settings) err := api.SSOSettingsService.Upsert(c.Req.Context(), &settings)
if err != nil { if err != nil {
return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err) return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err)
} }

View File

@@ -134,7 +134,7 @@ func TestSSOSettingsAPI_Update(t *testing.T) {
service := ssosettingstests.NewMockService(t) service := ssosettingstests.NewMockService(t)
if tt.expectedServiceCall { if tt.expectedServiceCall {
service.On("Upsert", mock.Anything, settings).Return(tt.expectedError).Once() service.On("Upsert", mock.Anything, &settings).Return(tt.expectedError).Once()
} }
server := setupTests(t, service) server := setupTests(t, service)

View File

@@ -84,7 +84,7 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err
return result, nil return result, nil
} }
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error { func (s *SSOSettingsStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if settings.Provider == "" { if settings.Provider == "" {
return ssosettings.ErrNotFound return ssosettings.ErrNotFound
} }
@@ -110,13 +110,10 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin
} }
_, err = sess.UseBool(isDeletedColumn).Update(updated, existing) _, err = sess.UseBool(isDeletedColumn).Update(updated, existing)
} else { } else {
_, err = sess.Insert(&models.SSOSettings{ settings.ID = uuid.New().String()
ID: uuid.New().String(), settings.Created = now
Provider: settings.Provider, settings.Updated = now
Settings: settings.Settings, _, err = sess.Insert(settings)
Created: now,
Updated: now,
})
} }
return err return err

View File

@@ -105,7 +105,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
}, },
} }
err := ssoSettingsStore.Upsert(context.Background(), settings) err := ssoSettingsStore.Upsert(context.Background(), &settings)
require.NoError(t, err) require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false) actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
@@ -143,7 +143,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
"client_secret": "this-is-a-new-secret", "client_secret": "this-is-a-new-secret",
}, },
} }
err = ssoSettingsStore.Upsert(context.Background(), newSettings) err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err) require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false) actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
@@ -181,7 +181,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
}, },
} }
err = ssoSettingsStore.Upsert(context.Background(), newSettings) err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err) require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false) actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
@@ -217,7 +217,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
"client_secret": "this-is-my-new-secret", "client_secret": "this-is-my-new-secret",
}, },
} }
err = ssoSettingsStore.Upsert(context.Background(), newSettings) err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err) require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false) actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
@@ -254,7 +254,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
}, },
} }
err = ssoSettingsStore.Upsert(context.Background(), settings) err = ssoSettingsStore.Upsert(context.Background(), &settings)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, ssosettings.ErrNotFound) require.ErrorIs(t, err, ssosettings.ErrNotFound)
}) })

View File

@@ -9,7 +9,7 @@ import (
var ( var (
ErrNotFound = errors.New("not found") ErrNotFound = errors.New("not found")
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("provider is invalid")) ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("Provider is invalid"))
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("settings field is invalid")) ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("Settings field is invalid"))
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("settings.clientId cannot be empty")) ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("ClientId cannot be empty"))
) )

View File

@@ -28,7 +28,7 @@ type Service interface {
// GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted // GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted
GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error) GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error)
// Upsert creates or updates the SSO settings for a given provider // Upsert creates or updates the SSO settings for a given provider
Upsert(ctx context.Context, settings models.SSOSettings) error Upsert(ctx context.Context, settings *models.SSOSettings) error
// Delete deletes the SSO settings for a given provider (soft delete) // Delete deletes the SSO settings for a given provider (soft delete)
Delete(ctx context.Context, provider string) error Delete(ctx context.Context, provider string) error
// Patch updates the specified SSO settings (key-value pairs) for a given provider // Patch updates the specified SSO settings (key-value pairs) for a given provider
@@ -62,6 +62,6 @@ type FallbackStrategy interface {
type Store interface { type Store interface {
Get(ctx context.Context, provider string) (*models.SSOSettings, error) Get(ctx context.Context, provider string) (*models.SSOSettings, error)
List(ctx context.Context) ([]*models.SSOSettings, error) List(ctx context.Context) ([]*models.SSOSettings, error)
Upsert(ctx context.Context, settings models.SSOSettings) error Upsert(ctx context.Context, settings *models.SSOSettings) error
Delete(ctx context.Context, provider string) error Delete(ctx context.Context, provider string) error
} }

View File

@@ -147,7 +147,7 @@ func (s *SSOSettingsService) ListWithRedactedSecrets(ctx context.Context) ([]*mo
return storeSettings, nil return storeSettings, nil
} }
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error { func (s *SSOSettingsService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if !isProviderConfigurable(settings.Provider) { if !isProviderConfigurable(settings.Provider) {
return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider) return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider)
} }
@@ -157,7 +157,7 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider) return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider)
} }
err := social.Validate(ctx, settings) err := social.Validate(ctx, *settings)
if err != nil { if err != nil {
return err return err
} }
@@ -167,6 +167,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
return err return err
} }
secrets := collectSecrets(settings, storedSettings)
settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings) settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings)
if err != nil { if err != nil {
return err return err
@@ -178,7 +180,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
} }
go func() { go func() {
err = social.Reload(context.Background(), settings) settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets)
err = social.Reload(context.Background(), *settings)
if err != nil { if err != nil {
s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err) s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
} }
@@ -249,7 +252,7 @@ func (s *SSOSettingsService) getFallbackStrategyFor(provider string) (ssosetting
func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) { func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) {
result := make(map[string]any) result := make(map[string]any)
for k, v := range settings { for k, v := range settings {
if isSecret(k) { if isSecret(k) && v != "" {
strValue, ok := v.(string) strValue, ok := v.(string)
if !ok { if !ok {
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v) return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
@@ -315,20 +318,24 @@ func (s *SSOSettingsService) doReload(ctx context.Context) {
// mergeSSOSettings merges the settings from the database with the system settings // mergeSSOSettings merges the settings from the database with the system settings
// Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings) // Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings)
// and the rest of the settings are loaded from the system settings // and the rest of the settings have to be loaded from the system settings
func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings { func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings {
if dbSettings == nil { if dbSettings == nil {
s.logger.Debug("No SSO Settings found in the database, using system settings") s.logger.Debug("No SSO Settings found in the database, using system settings")
return systemSettings return systemSettings
} }
s.logger.Debug("Merging SSO Settings", "dbSettings", dbSettings.Settings, "systemSettings", systemSettings.Settings) s.logger.Debug("Merging SSO Settings", "dbSettings", removeSecrets(dbSettings.Settings), "systemSettings", removeSecrets(systemSettings.Settings))
finalSettings := mergeSettings(dbSettings.Settings, systemSettings.Settings) result := &models.SSOSettings{
Provider: dbSettings.Provider,
Source: dbSettings.Source,
Settings: mergeSettings(dbSettings.Settings, systemSettings.Settings),
Created: dbSettings.Created,
Updated: dbSettings.Updated,
}
dbSettings.Settings = finalSettings return result
return dbSettings
} }
func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) { func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
@@ -358,6 +365,22 @@ func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[st
return settings, nil return settings, nil
} }
// removeSecrets removes all the secrets from the map and replaces them with a redacted password
// and returns a new map
func removeSecrets(settings map[string]any) map[string]any {
result := make(map[string]any)
for k, v := range settings {
if isSecret(k) {
result[k] = setting.RedactedPassword
continue
}
result[k] = v
}
return result
}
// mergeSettings merges two maps in a way that the values from the first map are preserved
// and the values from the second map are added only if they don't exist in the first map
func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any { func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any {
settings := make(map[string]any) settings := make(map[string]any)
@@ -374,6 +397,32 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
return settings return settings
} }
// collectSecrets collects all the secrets from the request and the currently stored settings
// and returns a new map
func collectSecrets(settings *models.SSOSettings, storedSettings *models.SSOSettings) map[string]any {
secrets := map[string]any{}
for k, v := range settings.Settings {
if isSecret(k) {
if isNewSecretValue(v.(string)) {
secrets[k] = v.(string) // use the new value
continue
}
secrets[k] = storedSettings.Settings[k] // keep the currently stored value
}
}
return secrets
}
func overrideMaps(maps ...map[string]any) map[string]any {
result := make(map[string]any)
for _, m := range maps {
for k, v := range m {
result[k] = v
}
}
return result
}
func isSecret(fieldName string) bool { func isSecret(fieldName string) bool {
secretFieldPatterns := []string{"secret"} secretFieldPatterns := []string{"secret"}

View File

@@ -5,7 +5,10 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"maps"
"sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -772,16 +775,50 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
}, },
IsDeleted: false, IsDeleted: false,
} }
var wg sync.WaitGroup
wg.Add(1)
reloadable := ssosettingstests.NewMockReloadable(t) reloadable := ssosettingstests.NewMockReloadable(t)
reloadable.On("Validate", mock.Anything, settings).Return(nil) reloadable.On("Validate", mock.Anything, settings).Return(nil)
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe() reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool {
wg.Done()
return settings.Provider == provider &&
settings.ID == "someid" &&
maps.Equal(settings.Settings, map[string]any{
"client_id": "client-id",
"client_secret": "client-secret",
"enabled": true,
})
})).Return(nil).Once()
env.reloadables[provider] = reloadable env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
env.secrets.On("Decrypt", mock.Anything, []byte("encrypted-current-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings) env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
currentTime := time.Now()
settings.ID = "someid"
settings.Created = currentTime
settings.Updated = currentTime
env.store.ActualSSOSettings = *settings
return nil
}
env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) {
return &models.SSOSettings{
ID: "someid",
Provider: provider,
Settings: map[string]any{
"client_secret": base64.RawStdEncoding.EncodeToString([]byte("encrypted-current-client-secret")),
},
}, nil
}
err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err) require.NoError(t, err)
// Wait for the goroutine first to assert the Reload call
wg.Wait()
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
require.EqualValues(t, settings, env.store.ActualSSOSettings) require.EqualValues(t, settings, env.store.ActualSSOSettings)
}) })
@@ -790,7 +827,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env := setupTestEnv(t) env := setupTestEnv(t)
provider := social.GrafanaComProviderName provider := social.GrafanaComProviderName
settings := models.SSOSettings{ settings := &models.SSOSettings{
Provider: provider, Provider: provider,
Settings: map[string]any{ Settings: map[string]any{
"client_id": "client-id", "client_id": "client-id",
@@ -811,7 +848,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env := setupTestEnv(t) env := setupTestEnv(t)
provider := social.AzureADProviderName provider := social.AzureADProviderName
settings := models.SSOSettings{ settings := &models.SSOSettings{
Provider: provider, Provider: provider,
Settings: map[string]any{ Settings: map[string]any{
"client_id": "client-id", "client_id": "client-id",
@@ -847,14 +884,14 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed")) reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed"))
env.reloadables[provider] = reloadable env.reloadables[provider] = reloadable
err := env.service.Upsert(context.Background(), settings) err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err) require.Error(t, err)
}) })
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) { t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
env := setupTestEnv(t) env := setupTestEnv(t)
settings := models.SSOSettings{ settings := &models.SSOSettings{
Provider: social.AzureADProviderName, Provider: social.AzureADProviderName,
Settings: map[string]any{ Settings: map[string]any{
"client_id": "client-id", "client_id": "client-id",
@@ -889,7 +926,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.reloadables[provider] = reloadable env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once() env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
err := env.service.Upsert(context.Background(), settings) err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err) require.Error(t, err)
}) })
@@ -921,7 +958,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once() env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings) err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err) require.NoError(t, err)
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret")) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret"))
@@ -950,11 +987,11 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
return &models.SSOSettings{}, nil return &models.SSOSettings{}, nil
} }
env.store.UpsertFn = func(ctx context.Context, settings models.SSOSettings) error { env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
return errors.New("failed to upsert settings") return errors.New("failed to upsert settings")
} }
err := env.service.Upsert(context.Background(), settings) err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err) require.Error(t, err)
}) })
@@ -978,7 +1015,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.reloadables[provider] = reloadable env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings) err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err) require.NoError(t, err)
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
@@ -1098,6 +1135,22 @@ func TestSSOSettingsService_decryptSecrets(t *testing.T) {
"other_secret": "decrypted-other-secret", "other_secret": "decrypted-other-secret",
}, },
}, },
{
name: "should not decrypt when a secret is empty",
setup: func(env testEnv) {
env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once()
},
settings: map[string]any{
"enabled": true,
"client_secret": "",
"other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")),
},
want: map[string]any{
"enabled": true,
"client_secret": "",
"other_secret": "decrypted-other-secret",
},
},
{ {
name: "should return an error if data is not a string", name: "should return an error if data is not a string",
settings: map[string]any{ settings: map[string]any{

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.38.0. DO NOT EDIT. // Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests package ssosettingstests

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.38.0. DO NOT EDIT. // Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests package ssosettingstests
@@ -183,7 +183,7 @@ func (_m *MockService) Reload(ctx context.Context, provider string) {
} }
// Upsert provides a mock function with given fields: ctx, settings // Upsert provides a mock function with given fields: ctx, settings
func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error { func (_m *MockService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
ret := _m.Called(ctx, settings) ret := _m.Called(ctx, settings)
if len(ret) == 0 { if len(ret) == 0 {
@@ -191,7 +191,7 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings)
} }
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
r0 = rf(ctx, settings) r0 = rf(ctx, settings)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)

View File

@@ -17,7 +17,7 @@ type FakeStore struct {
ActualSSOSettings models.SSOSettings ActualSSOSettings models.SSOSettings
GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error) GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error)
UpsertFn func(ctx context.Context, settings models.SSOSettings) error UpsertFn func(ctx context.Context, settings *models.SSOSettings) error
} }
func NewFakeStore() *FakeStore { func NewFakeStore() *FakeStore {
@@ -35,12 +35,12 @@ func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
return f.ExpectedSSOSettings, f.ExpectedError return f.ExpectedSSOSettings, f.ExpectedError
} }
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error { func (f *FakeStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if f.UpsertFn != nil { if f.UpsertFn != nil {
return f.UpsertFn(ctx, settings) return f.UpsertFn(ctx, settings)
} }
f.ActualSSOSettings = settings f.ActualSSOSettings = *settings
return f.ExpectedError return f.ExpectedError
} }

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.39.2. DO NOT EDIT. // Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests package ssosettingstests
@@ -93,7 +93,7 @@ func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
} }
// Upsert provides a mock function with given fields: ctx, settings // Upsert provides a mock function with given fields: ctx, settings
func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error { func (_m *MockStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
ret := _m.Called(ctx, settings) ret := _m.Called(ctx, settings)
if len(ret) == 0 { if len(ret) == 0 {
@@ -101,7 +101,7 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er
} }
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
r0 = rf(ctx, settings) r0 = rf(ctx, settings)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)

View File

@@ -2,6 +2,7 @@ package strategies
import ( import (
"context" "context"
"maps"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings"
@@ -31,7 +32,10 @@ func (s *OAuthStrategy) IsMatch(provider string) bool {
} }
func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) { func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) {
return s.settingsByProvider[provider], nil providerConfig := s.settingsByProvider[provider]
result := make(map[string]any, len(providerConfig))
maps.Copy(result, providerConfig)
return result, nil
} }
func (s *OAuthStrategy) loadAllSettings() { func (s *OAuthStrategy) loadAllSettings() {