AuthN: Support reloading SSO config after the sso settings have changed (#80734)

* Add AuthNSvc reload handling

* Working, need to add test

* Remove commented out code

* Add Reload implementation to connectors

* Align and add tests, refactor

* Add more tests, linting

* Add extra checks + tests to oauth client

* Clean up based on reviews

* Move config instantiation into newSocialBase

* Use specific error
This commit is contained in:
Misi
2024-01-22 14:54:48 +01:00
committed by GitHub
parent 1f4a520b9d
commit 20bb0a3ab1
31 changed files with 889 additions and 217 deletions

View File

@@ -73,16 +73,15 @@ type keySetJWKS struct {
}
func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD {
config := createOAuthConfig(info, cfg, social.AzureADProviderName)
provider := &SocialAzureAD{
SocialBase: newSocialBase(social.AzureADProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.AzureADProviderName, info, features, cfg),
cache: cache,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false),
}
if info.UseRefreshToken {
appendUniqueScope(config, social.OfflineAccessScope)
appendUniqueScope(provider.Config, social.OfflineAccessScope)
}
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -164,6 +163,27 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
}, nil
}
func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.AzureADProviderName, newInfo, s.features, s.cfg)
if newInfo.UseRefreshToken {
appendUniqueScope(s.Config, social.OfflineAccessScope)
}
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
s.forceUseGraphAPI = MustBool(newInfo.Extra[forceUseGraphAPIKey], false)
return nil
}
func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error {
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {

View File

@@ -1053,6 +1053,7 @@ func TestSocialAzureAD_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -1073,6 +1074,14 @@ func TestSocialAzureAD_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/azuread",
},
},
{
name: "fails if settings contain invalid values",
@@ -1092,6 +1101,11 @@ func TestSocialAzureAD_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/azuread",
},
},
}
@@ -1102,10 +1116,65 @@ func TestSocialAzureAD_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
expectedForceUseGraphApi bool
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
"force_use_graph_api": "true",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
expectedForceUseGraphApi: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewAzureADProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedForceUseGraphApi, s.forceUseGraphAPI)
})
}
}

View File

@@ -46,9 +46,8 @@ type SocialGenericOAuth struct {
}
func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth {
config := createOAuthConfig(info, cfg, social.GenericOAuthProviderName)
provider := &SocialGenericOAuth{
SocialBase: newSocialBase(social.GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.GenericOAuthProviderName, info, features, cfg),
teamsUrl: info.TeamsUrl,
emailAttributeName: info.EmailAttributeName,
emailAttributePath: info.EmailAttributePath,
@@ -84,6 +83,31 @@ func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SS
return nil
}
func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GenericOAuthProviderName, newInfo, s.features, s.cfg)
s.teamsUrl = newInfo.TeamsUrl
s.emailAttributeName = newInfo.EmailAttributeName
s.emailAttributePath = newInfo.EmailAttributePath
s.nameAttributePath = newInfo.Extra[nameAttributePathKey]
s.groupsAttributePath = newInfo.GroupsAttributePath
s.loginAttributePath = newInfo.Extra[loginAttributePathKey]
s.idTokenAttributeName = newInfo.Extra[idTokenAttributeNameKey]
s.teamIdsAttributePath = newInfo.TeamIdsAttributePath
s.teamIds = util.SplitString(newInfo.Extra[teamIdsKey])
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil
}
// TODOD: remove this in the next PR and use the isGroupMember from social.go
func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
info := s.GetOAuthInfo()

View File

@@ -981,6 +981,7 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -1001,6 +1002,14 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/generic_oauth",
},
},
{
name: "fails if settings contain invalid values",
@@ -1020,6 +1029,11 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/generic_oauth",
},
},
}
@@ -1030,10 +1044,116 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestGenericOAuth_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedTeamsUrl string
expectedEmailAttributeName string
expectedEmailAttributePath string
expectedNameAttributePath string
expectedGroupsAttributePath string
expectedLoginAttributePath string
expectedIdTokenAttributeName string
expectedTeamIdsAttributePath string
expectedTeamIds []string
expectedAllowedOrganizations []string
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
TeamsUrl: "https://host/users",
EmailAttributePath: "email-attr-path",
EmailAttributeName: "email-attr-name",
GroupsAttributePath: "groups-attr-path",
TeamIdsAttributePath: "team-ids-attr-path",
Extra: map[string]string{
teamIdsKey: "team1",
allowedOrganizationsKey: "org1",
loginAttributePathKey: "login-attr-path",
idTokenAttributeNameKey: "id-token-attr-name",
nameAttributePathKey: "name-attr-path",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"teams_url": "https://host/v2/users",
"email_attribute_path": "new-email-attr-path",
"email_attribute_name": "new-email-attr-name",
"groups_attribute_path": "new-group-attr-path",
"team_ids_attribute_path": "new-team-ids-attr-path",
teamIdsKey: "team1,team2",
allowedOrganizationsKey: "org1,org2",
loginAttributePathKey: "new-login-attr-path",
idTokenAttributeNameKey: "new-id-token-attr-name",
nameAttributePathKey: "new-name-attr-path",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
TeamsUrl: "https://host/v2/users",
EmailAttributePath: "new-email-attr-path",
EmailAttributeName: "new-email-attr-name",
GroupsAttributePath: "new-group-attr-path",
TeamIdsAttributePath: "new-team-ids-attr-path",
Extra: map[string]string{
teamIdsKey: "team1,team2",
allowedOrganizationsKey: "org1,org2",
loginAttributePathKey: "new-login-attr-path",
idTokenAttributeNameKey: "new-id-token-attr-name",
nameAttributePathKey: "new-name-attr-path",
},
},
expectedTeamsUrl: "https://host/v2/users",
expectedEmailAttributeName: "new-email-attr-name",
expectedEmailAttributePath: "new-email-attr-path",
expectedGroupsAttributePath: "new-group-attr-path",
expectedTeamIdsAttributePath: "new-team-ids-attr-path",
expectedTeamIds: []string{"team1", "team2"},
expectedAllowedOrganizations: []string{"org1", "org2"},
expectedLoginAttributePath: "new-login-attr-path",
expectedIdTokenAttributeName: "new-id-token-attr-name",
expectedNameAttributePath: "new-name-attr-path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewGenericOAuthProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedTeamsUrl, s.teamsUrl)
require.EqualValues(t, tc.expectedEmailAttributeName, s.emailAttributeName)
require.EqualValues(t, tc.expectedEmailAttributePath, s.emailAttributePath)
require.EqualValues(t, tc.expectedGroupsAttributePath, s.groupsAttributePath)
require.EqualValues(t, tc.expectedTeamIdsAttributePath, s.teamIdsAttributePath)
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedLoginAttributePath, s.loginAttributePath)
require.EqualValues(t, tc.expectedIdTokenAttributeName, s.idTokenAttributeName)
require.EqualValues(t, tc.expectedNameAttributePath, s.nameAttributePath)
})
}
}

View File

@@ -57,9 +57,8 @@ func NewGitHubProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings sso
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted)
config := createOAuthConfig(info, cfg, social.GitHubProviderName)
provider := &SocialGithub{
SocialBase: newSocialBase(social.GitHubProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.GitHubProviderName, info, features, cfg),
teamIds: teamIds,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
}
@@ -86,7 +85,37 @@ func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSetti
return err
}
// add specific validation rules for Github
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted)
if len(teamIdsSplitted) != len(teamIds) {
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
return ssosettings.ErrInvalidSettings.Errorf("Failed to parse team ids. Team ids must be a list of numbers.")
}
return nil
}
func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
teamIdsSplitted := util.SplitString(newInfo.Extra[teamIdsKey])
teamIds := mustInts(teamIdsSplitted)
if len(teamIdsSplitted) != len(teamIds) {
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GitHubProviderName, newInfo, s.features, s.cfg)
s.teamIds = teamIds
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil
}

View File

@@ -407,6 +407,7 @@ func TestSocialGitHub_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -427,6 +428,14 @@ func TestSocialGitHub_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/github",
},
},
{
name: "fails if settings contain invalid values",
@@ -446,6 +455,11 @@ func TestSocialGitHub_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/github",
},
},
}
@@ -456,10 +470,67 @@ func TestSocialGitHub_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestGitHub_Reload_ExtraFields(t *testing.T) {
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
expectedTeamIds []int
}{
{
name: "successfully reloads the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
"team_ids": "",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
"team_ids": "123,456",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
AuthStyle: "inheader",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
"force_use_graph_api": "false",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
expectedTeamIds: []int{123, 456},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := NewGitHubProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
})
}
}

View File

@@ -53,9 +53,8 @@ type userData struct {
}
func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab {
config := createOAuthConfig(info, cfg, social.GitlabProviderName)
provider := &SocialGitlab{
SocialBase: newSocialBase(social.GitlabProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.GitlabProviderName, info, features, cfg),
}
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -76,7 +75,19 @@ func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSetti
return err
}
// add specific validation rules for Gitlab
return nil
}
func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GitlabProviderName, newInfo, s.features, s.cfg)
return nil
}

View File

@@ -164,7 +164,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
for _, test := range tests {
provider.info.RoleAttributePath = test.RoleAttributePath
provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin
provider.autoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole)
provider.cfg.AutoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole)
provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict
provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync
@@ -525,6 +525,7 @@ func TestSocialGitlab_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -545,6 +546,14 @@ func TestSocialGitlab_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/gitlab",
},
},
{
name: "fails if settings contain invalid values",
@@ -564,6 +573,11 @@ func TestSocialGitlab_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/gitlab",
},
},
}
@@ -574,10 +588,12 @@ func TestSocialGitlab_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}

View File

@@ -39,9 +39,8 @@ type googleUserData struct {
}
func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle {
config := createOAuthConfig(info, cfg, social.GoogleProviderName)
provider := &SocialGoogle{
SocialBase: newSocialBase(social.GoogleProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.GoogleProviderName, info, features, cfg),
}
if strings.HasPrefix(info.ApiUrl, legacyAPIURL) {
@@ -71,6 +70,24 @@ func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSetti
return nil
}
func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
if strings.HasPrefix(newInfo.ApiUrl, legacyAPIURL) {
s.log.Warn("Using legacy Google API URL, please update your configuration")
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GoogleProviderName, newInfo, s.features, s.cfg)
return nil
}
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
info := s.GetOAuthInfo()
@@ -225,7 +242,7 @@ type googleGroupResp struct {
}
func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) {
s.log.Debug("Retrieving groups", "scopes", s.SocialBase.Config.Scopes)
s.log.Debug("Retrieving groups", "scopes", s.Config.Scopes)
if !slices.Contains(s.Scopes, googleIAMScope) {
return nil, nil
}

View File

@@ -730,6 +730,7 @@ func TestSocialGoogle_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -750,6 +751,14 @@ func TestSocialGoogle_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/google",
},
},
{
name: "fails if settings contain invalid values",
@@ -769,6 +778,11 @@ func TestSocialGoogle_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/google",
},
},
}
@@ -779,10 +793,12 @@ func TestSocialGoogle_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}

View File

@@ -39,9 +39,8 @@ func NewGrafanaComProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings
info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token"
info.AuthStyle = "inheader"
config := createOAuthConfig(info, cfg, social.GrafanaComProviderName)
provider := &SocialGrafanaCom{
SocialBase: newSocialBase(social.GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.GrafanaComProviderName, info, features, cfg),
url: cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
}
@@ -69,6 +68,28 @@ func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOS
return nil
}
func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
// Override necessary settings
newInfo.AuthUrl = s.cfg.GrafanaComURL + "/oauth2/authorize"
newInfo.TokenUrl = s.cfg.GrafanaComURL + "/api/oauth2/token"
newInfo.AuthStyle = "inheader"
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.GrafanaComProviderName, newInfo, s.features, s.cfg)
s.url = s.cfg.GrafanaComURL
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
return nil
}
func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool {
return true
}

View File

@@ -7,6 +7,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt"
@@ -200,6 +201,7 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -219,6 +221,19 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
AuthUrl: GrafanaComURL + "/oauth2/authorize",
TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: GrafanaComURL + "/oauth2/authorize",
TokenURL: GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: "/login/grafana_com",
},
},
{
@@ -243,6 +258,16 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: GrafanaComURL + "/oauth2/authorize",
TokenURL: GrafanaComURL + "/api/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: "/login/grafana_com",
},
},
}
@@ -256,10 +281,68 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}
func TestSocialGrafanaCom_Reload_ExtraFields(t *testing.T) {
const GrafanaComURL = "http://localhost:3000"
testCases := []struct {
name string
settings ssoModels.SSOSettings
info *social.OAuthInfo
expectError bool
expectedInfo *social.OAuthInfo
expectedAllowedOrganizations []string
}{
{
name: "successfully reloads the allowed organizations when they are set in the settings",
info: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Extra: map[string]string{
"allowed_organizations": "previous",
},
},
settings: ssoModels.SSOSettings{
Settings: map[string]any{
"allowed_organizations": "uuid-1234,uuid-5678",
},
},
expectedInfo: &social.OAuthInfo{
ClientId: "new-client-id",
ClientSecret: "new-client-secret",
Name: "a-new-name",
AuthUrl: GrafanaComURL + "/oauth2/authorize",
TokenUrl: GrafanaComURL + "/api/oauth2/token",
AuthStyle: "inheader",
Extra: map[string]string{
"allowed_organizations": "uuid-1234,uuid-5678",
},
},
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg := &setting.Cfg{
GrafanaComURL: GrafanaComURL,
}
s := NewGrafanaComProvider(tc.info, cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
err := s.Reload(context.Background(), tc.settings)
require.NoError(t, err)
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
})
}
}

View File

@@ -45,13 +45,12 @@ type OktaClaims struct {
}
func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta {
config := createOAuthConfig(info, cfg, social.OktaProviderName)
provider := &SocialOkta{
SocialBase: newSocialBase(social.OktaProviderName, config, info, cfg.AutoAssignOrgRole, features),
SocialBase: newSocialBase(social.OktaProviderName, info, features, cfg),
}
if info.UseRefreshToken {
appendUniqueScope(config, social.OfflineAccessScope)
appendUniqueScope(provider.Config, social.OfflineAccessScope)
}
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@@ -77,6 +76,23 @@ func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSetting
return nil
}
func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.reloadMutex.Lock()
defer s.reloadMutex.Unlock()
s.SocialBase = newSocialBase(social.OktaProviderName, newInfo, s.features, s.cfg)
if newInfo.UseRefreshToken {
appendUniqueScope(s.Config, social.OfflineAccessScope)
}
return nil
}
func (claims *OktaClaims) extractEmail() string {
if claims.Email == "" && claims.PreferredUsername != "" {
return claims.PreferredUsername

View File

@@ -198,6 +198,7 @@ func TestSocialOkta_Reload(t *testing.T) {
settings ssoModels.SSOSettings
expectError bool
expectedInfo *social.OAuthInfo
expectedConfig *oauth2.Config
}{
{
name: "SSO provider successfully updated",
@@ -218,6 +219,14 @@ func TestSocialOkta_Reload(t *testing.T) {
ClientSecret: "new-client-secret",
AuthUrl: "some-new-url",
},
expectedConfig: &oauth2.Config{
ClientID: "new-client-id",
ClientSecret: "new-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "some-new-url",
},
RedirectURL: "/login/okta",
},
},
{
name: "fails if settings contain invalid values",
@@ -237,6 +246,11 @@ func TestSocialOkta_Reload(t *testing.T) {
ClientId: "client-id",
ClientSecret: "client-secret",
},
expectedConfig: &oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
RedirectURL: "/login/okta",
},
},
}
@@ -247,10 +261,12 @@ func TestSocialOkta_Reload(t *testing.T) {
err := s.Reload(context.Background(), tc.settings)
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
return
}
require.NoError(t, err)
require.EqualValues(t, tc.expectedInfo, s.info)
require.EqualValues(t, tc.expectedConfig, s.Config)
})
}
}

View File

@@ -3,7 +3,6 @@ package connectors
import (
"bytes"
"compress/zlib"
"context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -21,32 +20,31 @@ import (
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/ssosettings"
ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models"
"github.com/grafana/grafana/pkg/setting"
)
type SocialBase struct {
*oauth2.Config
info *social.OAuthInfo
infoMutex sync.RWMutex
cfg *setting.Cfg
reloadMutex sync.RWMutex
log log.Logger
autoAssignOrgRole string
features featuremgmt.FeatureToggles
}
func newSocialBase(name string,
config *oauth2.Config,
info *social.OAuthInfo,
autoAssignOrgRole string,
features featuremgmt.FeatureToggles,
cfg *setting.Cfg,
) *SocialBase {
logger := log.New("oauth." + name)
return &SocialBase{
Config: config,
Config: createOAuthConfig(info, cfg, name),
info: info,
log: logger,
autoAssignOrgRole: autoAssignOrgRole,
features: features,
cfg: cfg,
}
}
@@ -60,7 +58,7 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin))
bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup))
bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains))
bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole))
bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.cfg.AutoAssignOrgRole))
bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath))
bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict))
bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync))
@@ -76,26 +74,12 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
}
func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo {
s.infoMutex.RLock()
defer s.infoMutex.RUnlock()
s.reloadMutex.RLock()
defer s.reloadMutex.RUnlock()
return s.info
}
func (s *SocialBase) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
if err != nil {
return fmt.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
}
s.infoMutex.Lock()
defer s.infoMutex.Unlock()
s.info = info
return nil
}
func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) {
if s.info.RoleAttributePath == "" {
if s.info.RoleAttributeStrict {
@@ -145,9 +129,9 @@ func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType,
// defaultRole returns the default role for the user based on the autoAssignOrgRole setting
// if legacy is enabled "" is returned indicating the previous role assignment is used.
func (s *SocialBase) defaultRole() org.RoleType {
if s.autoAssignOrgRole != "" {
if s.cfg.AutoAssignOrgRole != "" {
s.log.Debug("No role found, returning default.")
return org.RoleType(s.autoAssignOrgRole)
return org.RoleType(s.cfg.AutoAssignOrgRole)
}
// should never happen

View File

@@ -140,18 +140,8 @@ func ProvideService(
}
for name := range socialService.GetOAuthProviders() {
oauthCfg := socialService.GetOAuthInfoProvider(name)
if oauthCfg != nil && oauthCfg.Enabled {
clientName := authn.ClientWithPrefix(name)
connector, errConnector := socialService.GetConnector(name)
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
if errConnector != nil || errHTTPClient != nil {
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
} else {
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
}
}
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthTokenService, socialService))
}
// FIXME (jguer): move to User package

View File

@@ -8,7 +8,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
@@ -40,6 +39,9 @@ const (
)
var (
errOAuthClientDisabled = errutil.BadRequest("auth.oauth.disabled", errutil.WithPublicMessage("OAuth client is disabled"))
errOAuthInternal = errutil.Internal("auth.oauth.internal", errutil.WithPublicMessage("An internal error occurred in the OAuth client"))
errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
@@ -62,24 +64,25 @@ var _ authn.LogoutClient = new(OAuth)
var _ authn.RedirectClient = new(OAuth)
func ProvideOAuth(
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
name string, cfg *setting.Cfg, oauthService oauthtoken.OAuthTokenService,
socialService social.Service,
) *OAuth {
providerName := strings.TrimPrefix(name, "auth.client.")
return &OAuth{
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
name, fmt.Sprintf("oauth_%s", providerName), providerName,
log.New(name), cfg, oauthService, socialService,
}
}
type OAuth struct {
name string
moduleName string
providerName string
log log.Logger
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector
httpClient *http.Client
oauthService oauthtoken.OAuthTokenService
socialService social.Service
}
func (c *OAuth) Name() string {
@@ -88,6 +91,12 @@ func (c *OAuth) Name() string {
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
// get hashed state stored in cookie
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
if err != nil {
@@ -99,7 +108,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
}
// get state returned by the idp and hash it
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, oauthCfg.ClientSecret)
// compare the state returned by idp against the one we stored in cookie
if stateQuery != stateCookie.Value {
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
@@ -107,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
var opts []oauth2.AuthCodeOption
// if pkce is enabled for client validate we have the cookie and set it as url param
if c.oauthCfg.UsePKCE {
if oauthCfg.UsePKCE {
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
if err != nil {
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
@@ -115,15 +124,21 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
connector, errConnector := c.socialService.GetConnector(c.providerName)
httpClient, errHTTPClient := c.socialService.GetOAuthHttpClient(c.providerName)
if errConnector != nil || errHTTPClient != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth client: %w", c.name, errors.Join(errConnector, errHTTPClient))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// exchange auth code to a valid token
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
token, err := connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
if err != nil {
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
}
token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
userInfo, err := connector.UserInfo(ctx, connector.Client(clientCtx, token), token)
if err != nil {
var sErr *connectors.SocialError
if errors.As(err, &sErr) {
@@ -136,7 +151,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
}
if !c.connector.IsEmailAllowed(userInfo.Email) {
if !connector.IsEmailAllowed(userInfo.Email) {
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
}
@@ -167,7 +182,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
SyncTeams: true,
FetchSyncedUser: true,
SyncPermissions: true,
AllowSignUp: c.connector.IsSignupAllowed(),
AllowSignUp: connector.IsSignupAllowed(),
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
SyncOrgRoles: len(orgRoles) > 0,
LookUpParams: lookupParams,
@@ -178,12 +193,17 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
var opts []oauth2.AuthCodeOption
if c.oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
}
if oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, oauthCfg.HostedDomain))
}
var plainPKCE string
if c.oauthCfg.UsePKCE {
if oauthCfg.UsePKCE {
verifier, err := genPKCECodeVerifier()
if err != nil {
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
@@ -193,13 +213,18 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
}
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, oauthCfg.ClientSecret)
if err != nil {
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
}
connector, err := c.socialService.GetConnector(c.providerName)
if err != nil {
return nil, errOAuthInternal.Errorf("failed to get %s oauth connector: %w", c.name, err)
}
return &authn.Redirect{
URL: c.connector.AuthCodeURL(state, opts...),
URL: connector.AuthCodeURL(state, opts...),
Extra: map[string]string{
authn.KeyOAuthState: hashedSate,
authn.KeyOAuthPKCE: plainPKCE,
@@ -215,19 +240,25 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
}
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
if redirctURL == "" {
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
if !oauthCfg.Enabled {
c.log.FromContext(ctx).Debug("OAuth client is disabled")
return nil, false
}
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
if redirectURL == "" {
c.log.FromContext(ctx).Debug("No signout redirect url configured")
return nil, false
}
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
if isOICDLogout(redirectURL) && token != nil && token.Valid() {
if idToken, ok := token.Extra("id_token").(string); ok {
redirctURL = withIDTokenHint(redirctURL, idToken)
redirectURL = withIDTokenHint(redirectURL, idToken)
}
}
return &authn.Redirect{URL: redirctURL}, true
return &authn.Redirect{URL: redirectURL}, true
}
// genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
@@ -46,17 +47,23 @@ func TestOAuth_Authenticate(t *testing.T) {
{
desc: "should return error when missing state cookie",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when state cookie is present but don't have a value",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
addStateCookie: true,
stateCookieValue: "",
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when the client is not enabled",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{Enabled: false},
expectedErr: errOAuthClientDisabled,
},
{
desc: "should return error when state from ipd does not match stored state",
req: &authn.Request{HTTPRequest: &http.Request{
@@ -64,7 +71,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthInvalidState,
@@ -76,7 +83,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthMissingPKCE,
@@ -88,7 +95,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -103,7 +110,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -119,7 +126,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
@@ -157,7 +164,7 @@ func TestOAuth_Authenticate(t *testing.T) {
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
allowInsecureTakeover: true,
addStateCookie: true,
stateCookieValue: "some-state",
@@ -211,12 +218,18 @@ func TestOAuth_Authenticate(t *testing.T) {
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
ExpectedConnector: fakeConnector{
ExpectedUserInfo: tt.userInfo,
ExpectedToken: &oauth2.Token{},
ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
}, nil, nil)
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc)
identity, err := c.Authenticate(context.Background(), tt.req)
assert.ErrorIs(t, err, tt.expectedErr)
@@ -256,21 +269,27 @@ func TestOAuth_RedirectURL(t *testing.T) {
tests := []testCase{
{
desc: "should generate redirect url and state",
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with hosted domain option if configured",
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"},
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true},
numCallOptions: 1,
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with pkce if configured",
oauthCfg: &social.OAuthInfo{UsePKCE: true},
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
numCallOptions: 1,
authCodeUrlCalled: true,
},
{
desc: "should return error if the client is not enabled",
oauthCfg: &social.OAuthInfo{Enabled: false},
authCodeUrlCalled: false,
expectedErr: errOAuthClientDisabled,
},
}
for _, tt := range tests {
@@ -279,13 +298,18 @@ func TestOAuth_RedirectURL(t *testing.T) {
authCodeUrlCalled = false
)
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
ExpectedConnector: mockConnector{
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
authCodeUrlCalled = true
require.Len(t, opts, tt.numCallOptions)
return ""
},
}, nil, nil)
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), nil, fakeSocialSvc)
redirect, err := c.RedirectURL(context.Background(), nil)
assert.ErrorIs(t, err, tt.expectedErr)
@@ -321,12 +345,17 @@ func TestOAuth_Logout(t *testing.T) {
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{},
},
{
desc: "should not return redirect url when client is not enabled",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{Enabled: false},
},
{
desc: "should return redirect url for globably configured redirect url",
cfg: &setting.Cfg{
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{},
oauthCfg: &social.OAuthInfo{Enabled: true},
expectedURL: "http://idp.com/logout",
expectedOK: true,
},
@@ -334,6 +363,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should return redirect url for client configured redirect url",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout",
},
expectedURL: "http://idp.com/logout",
@@ -345,6 +375,7 @@ func TestOAuth_Logout(t *testing.T) {
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp-2.com/logout",
},
expectedURL: "http://idp-2.com/logout",
@@ -354,6 +385,7 @@ func TestOAuth_Logout(t *testing.T) {
desc: "should add id token hint if oicd logout is configured and token is valid",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
Enabled: true,
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
},
expectedURL: "http://idp.com/logout",
@@ -387,7 +419,10 @@ func TestOAuth_Logout(t *testing.T) {
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
fakeSocialSvc := &socialtest.FakeSocialService{
ExpectedAuthInfoProvider: tt.oauthCfg,
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc)
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})

View File

@@ -183,7 +183,7 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
settings.Provider = key
err := api.SSOSettingsService.Upsert(c.Req.Context(), settings)
err := api.SSOSettingsService.Upsert(c.Req.Context(), &settings)
if err != nil {
return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err)
}

View File

@@ -134,7 +134,7 @@ func TestSSOSettingsAPI_Update(t *testing.T) {
service := ssosettingstests.NewMockService(t)
if tt.expectedServiceCall {
service.On("Upsert", mock.Anything, settings).Return(tt.expectedError).Once()
service.On("Upsert", mock.Anything, &settings).Return(tt.expectedError).Once()
}
server := setupTests(t, service)

View File

@@ -84,7 +84,7 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err
return result, nil
}
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if settings.Provider == "" {
return ssosettings.ErrNotFound
}
@@ -110,13 +110,10 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin
}
_, err = sess.UseBool(isDeletedColumn).Update(updated, existing)
} else {
_, err = sess.Insert(&models.SSOSettings{
ID: uuid.New().String(),
Provider: settings.Provider,
Settings: settings.Settings,
Created: now,
Updated: now,
})
settings.ID = uuid.New().String()
settings.Created = now
settings.Updated = now
_, err = sess.Insert(settings)
}
return err

View File

@@ -105,7 +105,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
},
}
err := ssoSettingsStore.Upsert(context.Background(), settings)
err := ssoSettingsStore.Upsert(context.Background(), &settings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
@@ -143,7 +143,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
"client_secret": "this-is-a-new-secret",
},
}
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
@@ -181,7 +181,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
},
}
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
@@ -217,7 +217,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
"client_secret": "this-is-my-new-secret",
},
}
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
@@ -254,7 +254,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
},
}
err = ssoSettingsStore.Upsert(context.Background(), settings)
err = ssoSettingsStore.Upsert(context.Background(), &settings)
require.Error(t, err)
require.ErrorIs(t, err, ssosettings.ErrNotFound)
})

View File

@@ -9,7 +9,7 @@ import (
var (
ErrNotFound = errors.New("not found")
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("provider is invalid"))
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("settings field is invalid"))
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("settings.clientId cannot be empty"))
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("Provider is invalid"))
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("Settings field is invalid"))
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("ClientId cannot be empty"))
)

View File

@@ -28,7 +28,7 @@ type Service interface {
// GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted
GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error)
// Upsert creates or updates the SSO settings for a given provider
Upsert(ctx context.Context, settings models.SSOSettings) error
Upsert(ctx context.Context, settings *models.SSOSettings) error
// Delete deletes the SSO settings for a given provider (soft delete)
Delete(ctx context.Context, provider string) error
// Patch updates the specified SSO settings (key-value pairs) for a given provider
@@ -62,6 +62,6 @@ type FallbackStrategy interface {
type Store interface {
Get(ctx context.Context, provider string) (*models.SSOSettings, error)
List(ctx context.Context) ([]*models.SSOSettings, error)
Upsert(ctx context.Context, settings models.SSOSettings) error
Upsert(ctx context.Context, settings *models.SSOSettings) error
Delete(ctx context.Context, provider string) error
}

View File

@@ -147,7 +147,7 @@ func (s *SSOSettingsService) ListWithRedactedSecrets(ctx context.Context) ([]*mo
return storeSettings, nil
}
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
func (s *SSOSettingsService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if !isProviderConfigurable(settings.Provider) {
return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider)
}
@@ -157,7 +157,7 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider)
}
err := social.Validate(ctx, settings)
err := social.Validate(ctx, *settings)
if err != nil {
return err
}
@@ -167,6 +167,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
return err
}
secrets := collectSecrets(settings, storedSettings)
settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings)
if err != nil {
return err
@@ -178,7 +180,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
}
go func() {
err = social.Reload(context.Background(), settings)
settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets)
err = social.Reload(context.Background(), *settings)
if err != nil {
s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
}
@@ -249,7 +252,7 @@ func (s *SSOSettingsService) getFallbackStrategyFor(provider string) (ssosetting
func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) {
result := make(map[string]any)
for k, v := range settings {
if isSecret(k) {
if isSecret(k) && v != "" {
strValue, ok := v.(string)
if !ok {
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
@@ -315,20 +318,24 @@ func (s *SSOSettingsService) doReload(ctx context.Context) {
// mergeSSOSettings merges the settings from the database with the system settings
// Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings)
// and the rest of the settings are loaded from the system settings
// and the rest of the settings have to be loaded from the system settings
func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings {
if dbSettings == nil {
s.logger.Debug("No SSO Settings found in the database, using system settings")
return systemSettings
}
s.logger.Debug("Merging SSO Settings", "dbSettings", dbSettings.Settings, "systemSettings", systemSettings.Settings)
s.logger.Debug("Merging SSO Settings", "dbSettings", removeSecrets(dbSettings.Settings), "systemSettings", removeSecrets(systemSettings.Settings))
finalSettings := mergeSettings(dbSettings.Settings, systemSettings.Settings)
result := &models.SSOSettings{
Provider: dbSettings.Provider,
Source: dbSettings.Source,
Settings: mergeSettings(dbSettings.Settings, systemSettings.Settings),
Created: dbSettings.Created,
Updated: dbSettings.Updated,
}
dbSettings.Settings = finalSettings
return dbSettings
return result
}
func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
@@ -358,6 +365,22 @@ func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[st
return settings, nil
}
// removeSecrets removes all the secrets from the map and replaces them with a redacted password
// and returns a new map
func removeSecrets(settings map[string]any) map[string]any {
result := make(map[string]any)
for k, v := range settings {
if isSecret(k) {
result[k] = setting.RedactedPassword
continue
}
result[k] = v
}
return result
}
// mergeSettings merges two maps in a way that the values from the first map are preserved
// and the values from the second map are added only if they don't exist in the first map
func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any {
settings := make(map[string]any)
@@ -374,6 +397,32 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
return settings
}
// collectSecrets collects all the secrets from the request and the currently stored settings
// and returns a new map
func collectSecrets(settings *models.SSOSettings, storedSettings *models.SSOSettings) map[string]any {
secrets := map[string]any{}
for k, v := range settings.Settings {
if isSecret(k) {
if isNewSecretValue(v.(string)) {
secrets[k] = v.(string) // use the new value
continue
}
secrets[k] = storedSettings.Settings[k] // keep the currently stored value
}
}
return secrets
}
func overrideMaps(maps ...map[string]any) map[string]any {
result := make(map[string]any)
for _, m := range maps {
for k, v := range m {
result[k] = v
}
}
return result
}
func isSecret(fieldName string) bool {
secretFieldPatterns := []string{"secret"}

View File

@@ -5,7 +5,10 @@ import (
"encoding/base64"
"errors"
"fmt"
"maps"
"sync"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -772,16 +775,50 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
},
IsDeleted: false,
}
var wg sync.WaitGroup
wg.Add(1)
reloadable := ssosettingstests.NewMockReloadable(t)
reloadable.On("Validate", mock.Anything, settings).Return(nil)
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe()
reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool {
wg.Done()
return settings.Provider == provider &&
settings.ID == "someid" &&
maps.Equal(settings.Settings, map[string]any{
"client_id": "client-id",
"client_secret": "client-secret",
"enabled": true,
})
})).Return(nil).Once()
env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
env.secrets.On("Decrypt", mock.Anything, []byte("encrypted-current-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings)
env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
currentTime := time.Now()
settings.ID = "someid"
settings.Created = currentTime
settings.Updated = currentTime
env.store.ActualSSOSettings = *settings
return nil
}
env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) {
return &models.SSOSettings{
ID: "someid",
Provider: provider,
Settings: map[string]any{
"client_secret": base64.RawStdEncoding.EncodeToString([]byte("encrypted-current-client-secret")),
},
}, nil
}
err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err)
// Wait for the goroutine first to assert the Reload call
wg.Wait()
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
require.EqualValues(t, settings, env.store.ActualSSOSettings)
})
@@ -790,7 +827,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env := setupTestEnv(t)
provider := social.GrafanaComProviderName
settings := models.SSOSettings{
settings := &models.SSOSettings{
Provider: provider,
Settings: map[string]any{
"client_id": "client-id",
@@ -811,7 +848,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env := setupTestEnv(t)
provider := social.AzureADProviderName
settings := models.SSOSettings{
settings := &models.SSOSettings{
Provider: provider,
Settings: map[string]any{
"client_id": "client-id",
@@ -847,14 +884,14 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed"))
env.reloadables[provider] = reloadable
err := env.service.Upsert(context.Background(), settings)
err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err)
})
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
env := setupTestEnv(t)
settings := models.SSOSettings{
settings := &models.SSOSettings{
Provider: social.AzureADProviderName,
Settings: map[string]any{
"client_id": "client-id",
@@ -889,7 +926,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
err := env.service.Upsert(context.Background(), settings)
err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err)
})
@@ -921,7 +958,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings)
err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err)
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret"))
@@ -950,11 +987,11 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
return &models.SSOSettings{}, nil
}
env.store.UpsertFn = func(ctx context.Context, settings models.SSOSettings) error {
env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
return errors.New("failed to upsert settings")
}
err := env.service.Upsert(context.Background(), settings)
err := env.service.Upsert(context.Background(), &settings)
require.Error(t, err)
})
@@ -978,7 +1015,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
env.reloadables[provider] = reloadable
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings)
err := env.service.Upsert(context.Background(), &settings)
require.NoError(t, err)
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
@@ -1098,6 +1135,22 @@ func TestSSOSettingsService_decryptSecrets(t *testing.T) {
"other_secret": "decrypted-other-secret",
},
},
{
name: "should not decrypt when a secret is empty",
setup: func(env testEnv) {
env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once()
},
settings: map[string]any{
"enabled": true,
"client_secret": "",
"other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")),
},
want: map[string]any{
"enabled": true,
"client_secret": "",
"other_secret": "decrypted-other-secret",
},
},
{
name: "should return an error if data is not a string",
settings: map[string]any{

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.38.0. DO NOT EDIT.
// Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.38.0. DO NOT EDIT.
// Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests
@@ -183,7 +183,7 @@ func (_m *MockService) Reload(ctx context.Context, provider string) {
}
// Upsert provides a mock function with given fields: ctx, settings
func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error {
func (_m *MockService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
ret := _m.Called(ctx, settings)
if len(ret) == 0 {
@@ -191,7 +191,7 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings)
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
r0 = rf(ctx, settings)
} else {
r0 = ret.Error(0)

View File

@@ -17,7 +17,7 @@ type FakeStore struct {
ActualSSOSettings models.SSOSettings
GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error)
UpsertFn func(ctx context.Context, settings models.SSOSettings) error
UpsertFn func(ctx context.Context, settings *models.SSOSettings) error
}
func NewFakeStore() *FakeStore {
@@ -35,12 +35,12 @@ func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
return f.ExpectedSSOSettings, f.ExpectedError
}
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
func (f *FakeStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
if f.UpsertFn != nil {
return f.UpsertFn(ctx, settings)
}
f.ActualSSOSettings = settings
f.ActualSSOSettings = *settings
return f.ExpectedError
}

View File

@@ -1,4 +1,4 @@
// Code generated by mockery v2.39.2. DO NOT EDIT.
// Code generated by mockery v2.40.1. DO NOT EDIT.
package ssosettingstests
@@ -93,7 +93,7 @@ func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
}
// Upsert provides a mock function with given fields: ctx, settings
func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
func (_m *MockStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
ret := _m.Called(ctx, settings)
if len(ret) == 0 {
@@ -101,7 +101,7 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
r0 = rf(ctx, settings)
} else {
r0 = ret.Error(0)

View File

@@ -2,6 +2,7 @@ package strategies
import (
"context"
"maps"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/ssosettings"
@@ -31,7 +32,10 @@ func (s *OAuthStrategy) IsMatch(provider string) bool {
}
func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) {
return s.settingsByProvider[provider], nil
providerConfig := s.settingsByProvider[provider]
result := make(map[string]any, len(providerConfig))
maps.Copy(result, providerConfig)
return result, nil
}
func (s *OAuthStrategy) loadAllSettings() {