mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AuthN: Support reloading SSO config after the sso settings have changed (#80734)
* Add AuthNSvc reload handling * Working, need to add test * Remove commented out code * Add Reload implementation to connectors * Align and add tests, refactor * Add more tests, linting * Add extra checks + tests to oauth client * Clean up based on reviews * Move config instantiation into newSocialBase * Use specific error
This commit is contained in:
@@ -73,16 +73,15 @@ type keySetJWKS struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD {
|
func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles, cache remotecache.CacheStorage) *SocialAzureAD {
|
||||||
config := createOAuthConfig(info, cfg, social.AzureADProviderName)
|
|
||||||
provider := &SocialAzureAD{
|
provider := &SocialAzureAD{
|
||||||
SocialBase: newSocialBase(social.AzureADProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.AzureADProviderName, info, features, cfg),
|
||||||
cache: cache,
|
cache: cache,
|
||||||
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
||||||
forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false),
|
forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false),
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.UseRefreshToken {
|
if info.UseRefreshToken {
|
||||||
appendUniqueScope(config, social.OfflineAccessScope)
|
appendUniqueScope(provider.Config, social.OfflineAccessScope)
|
||||||
}
|
}
|
||||||
|
|
||||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
||||||
@@ -164,6 +163,27 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.AzureADProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
|
if newInfo.UseRefreshToken {
|
||||||
|
appendUniqueScope(s.Config, social.OfflineAccessScope)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
|
||||||
|
s.forceUseGraphAPI = MustBool(newInfo.Extra[forceUseGraphAPIKey], false)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error {
|
func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1048,11 +1048,12 @@ func TestSocialAzureAD_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialAzureAD_Reload(t *testing.T) {
|
func TestSocialAzureAD_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -1073,6 +1074,14 @@ func TestSocialAzureAD_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/azuread",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -1092,6 +1101,11 @@ func TestSocialAzureAD_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/azuread",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1102,10 +1116,65 @@ func TestSocialAzureAD_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
settings ssoModels.SSOSettings
|
||||||
|
info *social.OAuthInfo
|
||||||
|
expectError bool
|
||||||
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedAllowedOrganizations []string
|
||||||
|
expectedForceUseGraphApi bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successfully reloads the settings",
|
||||||
|
info: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "previous",
|
||||||
|
"force_use_graph_api": "true",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
settings: ssoModels.SSOSettings{
|
||||||
|
Settings: map[string]any{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
"force_use_graph_api": "false",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInfo: &social.OAuthInfo{
|
||||||
|
ClientId: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Name: "a-new-name",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
"force_use_graph_api": "false",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
|
||||||
|
expectedForceUseGraphApi: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := NewAzureADProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||||
|
|
||||||
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
|
||||||
|
require.EqualValues(t, tc.expectedForceUseGraphApi, s.forceUseGraphAPI)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,9 +46,8 @@ type SocialGenericOAuth struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth {
|
func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGenericOAuth {
|
||||||
config := createOAuthConfig(info, cfg, social.GenericOAuthProviderName)
|
|
||||||
provider := &SocialGenericOAuth{
|
provider := &SocialGenericOAuth{
|
||||||
SocialBase: newSocialBase(social.GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.GenericOAuthProviderName, info, features, cfg),
|
||||||
teamsUrl: info.TeamsUrl,
|
teamsUrl: info.TeamsUrl,
|
||||||
emailAttributeName: info.EmailAttributeName,
|
emailAttributeName: info.EmailAttributeName,
|
||||||
emailAttributePath: info.EmailAttributePath,
|
emailAttributePath: info.EmailAttributePath,
|
||||||
@@ -84,6 +83,31 @@ func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SS
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.GenericOAuthProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
|
s.teamsUrl = newInfo.TeamsUrl
|
||||||
|
s.emailAttributeName = newInfo.EmailAttributeName
|
||||||
|
s.emailAttributePath = newInfo.EmailAttributePath
|
||||||
|
s.nameAttributePath = newInfo.Extra[nameAttributePathKey]
|
||||||
|
s.groupsAttributePath = newInfo.GroupsAttributePath
|
||||||
|
s.loginAttributePath = newInfo.Extra[loginAttributePathKey]
|
||||||
|
s.idTokenAttributeName = newInfo.Extra[idTokenAttributeNameKey]
|
||||||
|
s.teamIdsAttributePath = newInfo.TeamIdsAttributePath
|
||||||
|
s.teamIds = util.SplitString(newInfo.Extra[teamIdsKey])
|
||||||
|
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TODOD: remove this in the next PR and use the isGroupMember from social.go
|
// TODOD: remove this in the next PR and use the isGroupMember from social.go
|
||||||
func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
|
func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
|
||||||
info := s.GetOAuthInfo()
|
info := s.GetOAuthInfo()
|
||||||
|
|||||||
@@ -976,11 +976,12 @@ func TestSocialGenericOAuth_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialGenericOAuth_Reload(t *testing.T) {
|
func TestSocialGenericOAuth_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -1001,6 +1002,14 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/generic_oauth",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -1020,6 +1029,11 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/generic_oauth",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1030,10 +1044,116 @@ func TestSocialGenericOAuth_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericOAuth_Reload_ExtraFields(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
settings ssoModels.SSOSettings
|
||||||
|
info *social.OAuthInfo
|
||||||
|
expectError bool
|
||||||
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedTeamsUrl string
|
||||||
|
expectedEmailAttributeName string
|
||||||
|
expectedEmailAttributePath string
|
||||||
|
expectedNameAttributePath string
|
||||||
|
expectedGroupsAttributePath string
|
||||||
|
expectedLoginAttributePath string
|
||||||
|
expectedIdTokenAttributeName string
|
||||||
|
expectedTeamIdsAttributePath string
|
||||||
|
expectedTeamIds []string
|
||||||
|
expectedAllowedOrganizations []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successfully reloads the settings",
|
||||||
|
info: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
TeamsUrl: "https://host/users",
|
||||||
|
EmailAttributePath: "email-attr-path",
|
||||||
|
EmailAttributeName: "email-attr-name",
|
||||||
|
GroupsAttributePath: "groups-attr-path",
|
||||||
|
TeamIdsAttributePath: "team-ids-attr-path",
|
||||||
|
Extra: map[string]string{
|
||||||
|
teamIdsKey: "team1",
|
||||||
|
allowedOrganizationsKey: "org1",
|
||||||
|
loginAttributePathKey: "login-attr-path",
|
||||||
|
idTokenAttributeNameKey: "id-token-attr-name",
|
||||||
|
nameAttributePathKey: "name-attr-path",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
settings: ssoModels.SSOSettings{
|
||||||
|
Settings: map[string]any{
|
||||||
|
"client_id": "new-client-id",
|
||||||
|
"client_secret": "new-client-secret",
|
||||||
|
"teams_url": "https://host/v2/users",
|
||||||
|
"email_attribute_path": "new-email-attr-path",
|
||||||
|
"email_attribute_name": "new-email-attr-name",
|
||||||
|
"groups_attribute_path": "new-group-attr-path",
|
||||||
|
"team_ids_attribute_path": "new-team-ids-attr-path",
|
||||||
|
teamIdsKey: "team1,team2",
|
||||||
|
allowedOrganizationsKey: "org1,org2",
|
||||||
|
loginAttributePathKey: "new-login-attr-path",
|
||||||
|
idTokenAttributeNameKey: "new-id-token-attr-name",
|
||||||
|
nameAttributePathKey: "new-name-attr-path",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInfo: &social.OAuthInfo{
|
||||||
|
ClientId: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
TeamsUrl: "https://host/v2/users",
|
||||||
|
EmailAttributePath: "new-email-attr-path",
|
||||||
|
EmailAttributeName: "new-email-attr-name",
|
||||||
|
GroupsAttributePath: "new-group-attr-path",
|
||||||
|
TeamIdsAttributePath: "new-team-ids-attr-path",
|
||||||
|
Extra: map[string]string{
|
||||||
|
teamIdsKey: "team1,team2",
|
||||||
|
allowedOrganizationsKey: "org1,org2",
|
||||||
|
loginAttributePathKey: "new-login-attr-path",
|
||||||
|
idTokenAttributeNameKey: "new-id-token-attr-name",
|
||||||
|
nameAttributePathKey: "new-name-attr-path",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedTeamsUrl: "https://host/v2/users",
|
||||||
|
expectedEmailAttributeName: "new-email-attr-name",
|
||||||
|
expectedEmailAttributePath: "new-email-attr-path",
|
||||||
|
expectedGroupsAttributePath: "new-group-attr-path",
|
||||||
|
expectedTeamIdsAttributePath: "new-team-ids-attr-path",
|
||||||
|
expectedTeamIds: []string{"team1", "team2"},
|
||||||
|
expectedAllowedOrganizations: []string{"org1", "org2"},
|
||||||
|
expectedLoginAttributePath: "new-login-attr-path",
|
||||||
|
expectedIdTokenAttributeName: "new-id-token-attr-name",
|
||||||
|
expectedNameAttributePath: "new-name-attr-path",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := NewGenericOAuthProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
|
||||||
|
|
||||||
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
|
||||||
|
require.EqualValues(t, tc.expectedTeamsUrl, s.teamsUrl)
|
||||||
|
require.EqualValues(t, tc.expectedEmailAttributeName, s.emailAttributeName)
|
||||||
|
require.EqualValues(t, tc.expectedEmailAttributePath, s.emailAttributePath)
|
||||||
|
require.EqualValues(t, tc.expectedGroupsAttributePath, s.groupsAttributePath)
|
||||||
|
require.EqualValues(t, tc.expectedTeamIdsAttributePath, s.teamIdsAttributePath)
|
||||||
|
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
|
||||||
|
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
|
||||||
|
require.EqualValues(t, tc.expectedLoginAttributePath, s.loginAttributePath)
|
||||||
|
require.EqualValues(t, tc.expectedIdTokenAttributeName, s.idTokenAttributeName)
|
||||||
|
require.EqualValues(t, tc.expectedNameAttributePath, s.nameAttributePath)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,9 +57,8 @@ func NewGitHubProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings sso
|
|||||||
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
|
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
|
||||||
teamIds := mustInts(teamIdsSplitted)
|
teamIds := mustInts(teamIdsSplitted)
|
||||||
|
|
||||||
config := createOAuthConfig(info, cfg, social.GitHubProviderName)
|
|
||||||
provider := &SocialGithub{
|
provider := &SocialGithub{
|
||||||
SocialBase: newSocialBase(social.GitHubProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.GitHubProviderName, info, features, cfg),
|
||||||
teamIds: teamIds,
|
teamIds: teamIds,
|
||||||
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
||||||
}
|
}
|
||||||
@@ -86,7 +85,37 @@ func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSetti
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// add specific validation rules for Github
|
teamIdsSplitted := util.SplitString(info.Extra[teamIdsKey])
|
||||||
|
teamIds := mustInts(teamIdsSplitted)
|
||||||
|
|
||||||
|
if len(teamIdsSplitted) != len(teamIds) {
|
||||||
|
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("Failed to parse team ids. Team ids must be a list of numbers.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
teamIdsSplitted := util.SplitString(newInfo.Extra[teamIdsKey])
|
||||||
|
teamIds := mustInts(teamIdsSplitted)
|
||||||
|
|
||||||
|
if len(teamIdsSplitted) != len(teamIds) {
|
||||||
|
s.log.Warn("Failed to parse team ids. Team ids must be a list of numbers.", "teamIds", teamIdsSplitted)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.GitHubProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
|
s.teamIds = teamIds
|
||||||
|
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -402,11 +402,12 @@ func TestSocialGitHub_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialGitHub_Reload(t *testing.T) {
|
func TestSocialGitHub_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -427,6 +428,14 @@ func TestSocialGitHub_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/github",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -446,6 +455,11 @@ func TestSocialGitHub_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/github",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,10 +470,67 @@ func TestSocialGitHub_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHub_Reload_ExtraFields(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
settings ssoModels.SSOSettings
|
||||||
|
info *social.OAuthInfo
|
||||||
|
expectError bool
|
||||||
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedAllowedOrganizations []string
|
||||||
|
expectedTeamIds []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successfully reloads the settings",
|
||||||
|
info: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "previous",
|
||||||
|
"team_ids": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
settings: ssoModels.SSOSettings{
|
||||||
|
Settings: map[string]any{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
"team_ids": "123,456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInfo: &social.OAuthInfo{
|
||||||
|
ClientId: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Name: "a-new-name",
|
||||||
|
AuthStyle: "inheader",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
"force_use_graph_api": "false",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
|
||||||
|
expectedTeamIds: []int{123, 456},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := NewGitHubProvider(tc.info, setting.NewCfg(), &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
|
||||||
|
|
||||||
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
|
||||||
|
require.EqualValues(t, tc.expectedTeamIds, s.teamIds)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,9 +53,8 @@ type userData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab {
|
func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGitlab {
|
||||||
config := createOAuthConfig(info, cfg, social.GitlabProviderName)
|
|
||||||
provider := &SocialGitlab{
|
provider := &SocialGitlab{
|
||||||
SocialBase: newSocialBase(social.GitlabProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.GitlabProviderName, info, features, cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
||||||
@@ -76,7 +75,19 @@ func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSetti
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// add specific validation rules for Gitlab
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.GitlabProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
provider.info.RoleAttributePath = test.RoleAttributePath
|
provider.info.RoleAttributePath = test.RoleAttributePath
|
||||||
provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin
|
provider.info.AllowAssignGrafanaAdmin = test.Cfg.AllowAssignGrafanaAdmin
|
||||||
provider.autoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole)
|
provider.cfg.AutoAssignOrgRole = string(test.Cfg.AutoAssignOrgRole)
|
||||||
provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict
|
provider.info.RoleAttributeStrict = test.Cfg.RoleAttributeStrict
|
||||||
provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync
|
provider.info.SkipOrgRoleSync = test.Cfg.SkipOrgRoleSync
|
||||||
|
|
||||||
@@ -520,11 +520,12 @@ func TestSocialGitlab_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialGitlab_Reload(t *testing.T) {
|
func TestSocialGitlab_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -545,6 +546,14 @@ func TestSocialGitlab_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/gitlab",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -564,6 +573,11 @@ func TestSocialGitlab_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/gitlab",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -574,10 +588,12 @@ func TestSocialGitlab_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,9 +39,8 @@ type googleUserData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle {
|
func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialGoogle {
|
||||||
config := createOAuthConfig(info, cfg, social.GoogleProviderName)
|
|
||||||
provider := &SocialGoogle{
|
provider := &SocialGoogle{
|
||||||
SocialBase: newSocialBase(social.GoogleProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.GoogleProviderName, info, features, cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(info.ApiUrl, legacyAPIURL) {
|
if strings.HasPrefix(info.ApiUrl, legacyAPIURL) {
|
||||||
@@ -71,6 +70,24 @@ func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSetti
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(newInfo.ApiUrl, legacyAPIURL) {
|
||||||
|
s.log.Warn("Using legacy Google API URL, please update your configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.GoogleProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||||
info := s.GetOAuthInfo()
|
info := s.GetOAuthInfo()
|
||||||
|
|
||||||
@@ -225,7 +242,7 @@ type googleGroupResp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) {
|
func (s *SocialGoogle) retrieveGroups(ctx context.Context, client *http.Client, userData *googleUserData) ([]string, error) {
|
||||||
s.log.Debug("Retrieving groups", "scopes", s.SocialBase.Config.Scopes)
|
s.log.Debug("Retrieving groups", "scopes", s.Config.Scopes)
|
||||||
if !slices.Contains(s.Scopes, googleIAMScope) {
|
if !slices.Contains(s.Scopes, googleIAMScope) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -725,11 +725,12 @@ func TestSocialGoogle_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialGoogle_Reload(t *testing.T) {
|
func TestSocialGoogle_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -750,6 +751,14 @@ func TestSocialGoogle_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/google",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -769,6 +778,11 @@ func TestSocialGoogle_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/google",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -779,10 +793,12 @@ func TestSocialGoogle_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,9 +39,8 @@ func NewGrafanaComProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings
|
|||||||
info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token"
|
info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token"
|
||||||
info.AuthStyle = "inheader"
|
info.AuthStyle = "inheader"
|
||||||
|
|
||||||
config := createOAuthConfig(info, cfg, social.GrafanaComProviderName)
|
|
||||||
provider := &SocialGrafanaCom{
|
provider := &SocialGrafanaCom{
|
||||||
SocialBase: newSocialBase(social.GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.GrafanaComProviderName, info, features, cfg),
|
||||||
url: cfg.GrafanaComURL,
|
url: cfg.GrafanaComURL,
|
||||||
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]),
|
||||||
}
|
}
|
||||||
@@ -69,6 +68,28 @@ func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOS
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override necessary settings
|
||||||
|
newInfo.AuthUrl = s.cfg.GrafanaComURL + "/oauth2/authorize"
|
||||||
|
newInfo.TokenUrl = s.cfg.GrafanaComURL + "/api/oauth2/token"
|
||||||
|
newInfo.AuthStyle = "inheader"
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.GrafanaComProviderName, newInfo, s.features, s.cfg)
|
||||||
|
|
||||||
|
s.url = s.cfg.GrafanaComURL
|
||||||
|
s.allowedOrganizations = util.SplitString(newInfo.Extra[allowedOrganizationsKey])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool {
|
func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
@@ -195,11 +196,12 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
|
|||||||
const GrafanaComURL = "http://localhost:3000"
|
const GrafanaComURL = "http://localhost:3000"
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -219,6 +221,19 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
|
|||||||
ClientId: "new-client-id",
|
ClientId: "new-client-id",
|
||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
Name: "a-new-name",
|
Name: "a-new-name",
|
||||||
|
AuthUrl: GrafanaComURL + "/oauth2/authorize",
|
||||||
|
TokenUrl: GrafanaComURL + "/api/oauth2/token",
|
||||||
|
AuthStyle: "inheader",
|
||||||
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: GrafanaComURL + "/oauth2/authorize",
|
||||||
|
TokenURL: GrafanaComURL + "/api/oauth2/token",
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/grafana_com",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -243,6 +258,16 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
|
|||||||
TokenUrl: GrafanaComURL + "/api/oauth2/token",
|
TokenUrl: GrafanaComURL + "/api/oauth2/token",
|
||||||
AuthStyle: "inheader",
|
AuthStyle: "inheader",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: GrafanaComURL + "/oauth2/authorize",
|
||||||
|
TokenURL: GrafanaComURL + "/api/oauth2/token",
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/grafana_com",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,10 +281,68 @@ func TestSocialGrafanaCom_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSocialGrafanaCom_Reload_ExtraFields(t *testing.T) {
|
||||||
|
const GrafanaComURL = "http://localhost:3000"
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
settings ssoModels.SSOSettings
|
||||||
|
info *social.OAuthInfo
|
||||||
|
expectError bool
|
||||||
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedAllowedOrganizations []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successfully reloads the allowed organizations when they are set in the settings",
|
||||||
|
info: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "previous",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
settings: ssoModels.SSOSettings{
|
||||||
|
Settings: map[string]any{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedInfo: &social.OAuthInfo{
|
||||||
|
ClientId: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Name: "a-new-name",
|
||||||
|
AuthUrl: GrafanaComURL + "/oauth2/authorize",
|
||||||
|
TokenUrl: GrafanaComURL + "/api/oauth2/token",
|
||||||
|
AuthStyle: "inheader",
|
||||||
|
Extra: map[string]string{
|
||||||
|
"allowed_organizations": "uuid-1234,uuid-5678",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedAllowedOrganizations: []string{"uuid-1234", "uuid-5678"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
cfg := &setting.Cfg{
|
||||||
|
GrafanaComURL: GrafanaComURL,
|
||||||
|
}
|
||||||
|
s := NewGrafanaComProvider(tc.info, cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
|
||||||
|
|
||||||
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.EqualValues(t, tc.expectedAllowedOrganizations, s.allowedOrganizations)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,13 +45,12 @@ type OktaClaims struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta {
|
func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features featuremgmt.FeatureToggles) *SocialOkta {
|
||||||
config := createOAuthConfig(info, cfg, social.OktaProviderName)
|
|
||||||
provider := &SocialOkta{
|
provider := &SocialOkta{
|
||||||
SocialBase: newSocialBase(social.OktaProviderName, config, info, cfg.AutoAssignOrgRole, features),
|
SocialBase: newSocialBase(social.OktaProviderName, info, features, cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.UseRefreshToken {
|
if info.UseRefreshToken {
|
||||||
appendUniqueScope(config, social.OfflineAccessScope)
|
appendUniqueScope(provider.Config, social.OfflineAccessScope)
|
||||||
}
|
}
|
||||||
|
|
||||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
||||||
@@ -77,6 +76,23 @@ func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSetting
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
||||||
|
newInfo, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
||||||
|
if err != nil {
|
||||||
|
return ssosettings.ErrInvalidSettings.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.reloadMutex.Lock()
|
||||||
|
defer s.reloadMutex.Unlock()
|
||||||
|
|
||||||
|
s.SocialBase = newSocialBase(social.OktaProviderName, newInfo, s.features, s.cfg)
|
||||||
|
if newInfo.UseRefreshToken {
|
||||||
|
appendUniqueScope(s.Config, social.OfflineAccessScope)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (claims *OktaClaims) extractEmail() string {
|
func (claims *OktaClaims) extractEmail() string {
|
||||||
if claims.Email == "" && claims.PreferredUsername != "" {
|
if claims.Email == "" && claims.PreferredUsername != "" {
|
||||||
return claims.PreferredUsername
|
return claims.PreferredUsername
|
||||||
|
|||||||
@@ -193,11 +193,12 @@ func TestSocialOkta_Validate(t *testing.T) {
|
|||||||
|
|
||||||
func TestSocialOkta_Reload(t *testing.T) {
|
func TestSocialOkta_Reload(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
settings ssoModels.SSOSettings
|
settings ssoModels.SSOSettings
|
||||||
expectError bool
|
expectError bool
|
||||||
expectedInfo *social.OAuthInfo
|
expectedInfo *social.OAuthInfo
|
||||||
|
expectedConfig *oauth2.Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "SSO provider successfully updated",
|
name: "SSO provider successfully updated",
|
||||||
@@ -218,6 +219,14 @@ func TestSocialOkta_Reload(t *testing.T) {
|
|||||||
ClientSecret: "new-client-secret",
|
ClientSecret: "new-client-secret",
|
||||||
AuthUrl: "some-new-url",
|
AuthUrl: "some-new-url",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "new-client-id",
|
||||||
|
ClientSecret: "new-client-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "some-new-url",
|
||||||
|
},
|
||||||
|
RedirectURL: "/login/okta",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fails if settings contain invalid values",
|
name: "fails if settings contain invalid values",
|
||||||
@@ -237,6 +246,11 @@ func TestSocialOkta_Reload(t *testing.T) {
|
|||||||
ClientId: "client-id",
|
ClientId: "client-id",
|
||||||
ClientSecret: "client-secret",
|
ClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
expectedConfig: &oauth2.Config{
|
||||||
|
ClientID: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
RedirectURL: "/login/okta",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,10 +261,12 @@ func TestSocialOkta_Reload(t *testing.T) {
|
|||||||
err := s.Reload(context.Background(), tc.settings)
|
err := s.Reload(context.Background(), tc.settings)
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
return
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.EqualValues(t, tc.expectedInfo, s.info)
|
require.EqualValues(t, tc.expectedInfo, s.info)
|
||||||
|
require.EqualValues(t, tc.expectedConfig, s.Config)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package connectors
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/zlib"
|
"compress/zlib"
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -21,32 +20,31 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||||
ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SocialBase struct {
|
type SocialBase struct {
|
||||||
*oauth2.Config
|
*oauth2.Config
|
||||||
info *social.OAuthInfo
|
info *social.OAuthInfo
|
||||||
infoMutex sync.RWMutex
|
cfg *setting.Cfg
|
||||||
log log.Logger
|
reloadMutex sync.RWMutex
|
||||||
autoAssignOrgRole string
|
log log.Logger
|
||||||
features featuremgmt.FeatureToggles
|
features featuremgmt.FeatureToggles
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSocialBase(name string,
|
func newSocialBase(name string,
|
||||||
config *oauth2.Config,
|
|
||||||
info *social.OAuthInfo,
|
info *social.OAuthInfo,
|
||||||
autoAssignOrgRole string,
|
|
||||||
features featuremgmt.FeatureToggles,
|
features featuremgmt.FeatureToggles,
|
||||||
|
cfg *setting.Cfg,
|
||||||
) *SocialBase {
|
) *SocialBase {
|
||||||
logger := log.New("oauth." + name)
|
logger := log.New("oauth." + name)
|
||||||
|
|
||||||
return &SocialBase{
|
return &SocialBase{
|
||||||
Config: config,
|
Config: createOAuthConfig(info, cfg, name),
|
||||||
info: info,
|
info: info,
|
||||||
log: logger,
|
log: logger,
|
||||||
autoAssignOrgRole: autoAssignOrgRole,
|
features: features,
|
||||||
features: features,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +58,7 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
|
|||||||
bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin))
|
bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.info.AllowAssignGrafanaAdmin))
|
||||||
bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup))
|
bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.info.AllowSignup))
|
||||||
bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains))
|
bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.info.AllowedDomains))
|
||||||
bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole))
|
bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.cfg.AutoAssignOrgRole))
|
||||||
bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath))
|
bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.info.RoleAttributePath))
|
||||||
bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict))
|
bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.info.RoleAttributeStrict))
|
||||||
bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync))
|
bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.info.SkipOrgRoleSync))
|
||||||
@@ -76,26 +74,12 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo {
|
func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo {
|
||||||
s.infoMutex.RLock()
|
s.reloadMutex.RLock()
|
||||||
defer s.infoMutex.RUnlock()
|
defer s.reloadMutex.RUnlock()
|
||||||
|
|
||||||
return s.info
|
return s.info
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialBase) Reload(ctx context.Context, settings ssoModels.SSOSettings) error {
|
|
||||||
info, err := CreateOAuthInfoFromKeyValues(settings.Settings)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.infoMutex.Lock()
|
|
||||||
defer s.infoMutex.Unlock()
|
|
||||||
|
|
||||||
s.info = info
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) {
|
func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) {
|
||||||
if s.info.RoleAttributePath == "" {
|
if s.info.RoleAttributePath == "" {
|
||||||
if s.info.RoleAttributeStrict {
|
if s.info.RoleAttributeStrict {
|
||||||
@@ -145,9 +129,9 @@ func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType,
|
|||||||
// defaultRole returns the default role for the user based on the autoAssignOrgRole setting
|
// defaultRole returns the default role for the user based on the autoAssignOrgRole setting
|
||||||
// if legacy is enabled "" is returned indicating the previous role assignment is used.
|
// if legacy is enabled "" is returned indicating the previous role assignment is used.
|
||||||
func (s *SocialBase) defaultRole() org.RoleType {
|
func (s *SocialBase) defaultRole() org.RoleType {
|
||||||
if s.autoAssignOrgRole != "" {
|
if s.cfg.AutoAssignOrgRole != "" {
|
||||||
s.log.Debug("No role found, returning default.")
|
s.log.Debug("No role found, returning default.")
|
||||||
return org.RoleType(s.autoAssignOrgRole)
|
return org.RoleType(s.cfg.AutoAssignOrgRole)
|
||||||
}
|
}
|
||||||
|
|
||||||
// should never happen
|
// should never happen
|
||||||
|
|||||||
@@ -140,18 +140,8 @@ func ProvideService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for name := range socialService.GetOAuthProviders() {
|
for name := range socialService.GetOAuthProviders() {
|
||||||
oauthCfg := socialService.GetOAuthInfoProvider(name)
|
clientName := authn.ClientWithPrefix(name)
|
||||||
if oauthCfg != nil && oauthCfg.Enabled {
|
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthTokenService, socialService))
|
||||||
clientName := authn.ClientWithPrefix(name)
|
|
||||||
|
|
||||||
connector, errConnector := socialService.GetConnector(name)
|
|
||||||
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
|
|
||||||
if errConnector != nil || errHTTPClient != nil {
|
|
||||||
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
|
|
||||||
} else {
|
|
||||||
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME (jguer): move to User package
|
// FIXME (jguer): move to User package
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -40,6 +39,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
errOAuthClientDisabled = errutil.BadRequest("auth.oauth.disabled", errutil.WithPublicMessage("OAuth client is disabled"))
|
||||||
|
errOAuthInternal = errutil.Internal("auth.oauth.internal", errutil.WithPublicMessage("An internal error occurred in the OAuth client"))
|
||||||
|
|
||||||
errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
|
errOAuthGenPKCE = errutil.Internal("auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
|
||||||
errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
|
errOAuthMissingPKCE = errutil.BadRequest("auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
|
||||||
|
|
||||||
@@ -62,24 +64,25 @@ var _ authn.LogoutClient = new(OAuth)
|
|||||||
var _ authn.RedirectClient = new(OAuth)
|
var _ authn.RedirectClient = new(OAuth)
|
||||||
|
|
||||||
func ProvideOAuth(
|
func ProvideOAuth(
|
||||||
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
name string, cfg *setting.Cfg, oauthService oauthtoken.OAuthTokenService,
|
||||||
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
|
socialService social.Service,
|
||||||
) *OAuth {
|
) *OAuth {
|
||||||
|
providerName := strings.TrimPrefix(name, "auth.client.")
|
||||||
return &OAuth{
|
return &OAuth{
|
||||||
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
name, fmt.Sprintf("oauth_%s", providerName), providerName,
|
||||||
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
|
log.New(name), cfg, oauthService, socialService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuth struct {
|
type OAuth struct {
|
||||||
name string
|
name string
|
||||||
moduleName string
|
moduleName string
|
||||||
|
providerName string
|
||||||
log log.Logger
|
log log.Logger
|
||||||
cfg *setting.Cfg
|
cfg *setting.Cfg
|
||||||
oauthCfg *social.OAuthInfo
|
|
||||||
connector social.SocialConnector
|
oauthService oauthtoken.OAuthTokenService
|
||||||
httpClient *http.Client
|
socialService social.Service
|
||||||
oauthService oauthtoken.OAuthTokenService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OAuth) Name() string {
|
func (c *OAuth) Name() string {
|
||||||
@@ -88,6 +91,12 @@ func (c *OAuth) Name() string {
|
|||||||
|
|
||||||
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
||||||
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
|
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
|
||||||
|
|
||||||
|
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
|
||||||
|
if !oauthCfg.Enabled {
|
||||||
|
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
|
||||||
|
}
|
||||||
|
|
||||||
// get hashed state stored in cookie
|
// get hashed state stored in cookie
|
||||||
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
|
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,7 +108,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get state returned by the idp and hash it
|
// get state returned by the idp and hash it
|
||||||
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, oauthCfg.ClientSecret)
|
||||||
// compare the state returned by idp against the one we stored in cookie
|
// compare the state returned by idp against the one we stored in cookie
|
||||||
if stateQuery != stateCookie.Value {
|
if stateQuery != stateCookie.Value {
|
||||||
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
|
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
|
||||||
@@ -107,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
|
|
||||||
var opts []oauth2.AuthCodeOption
|
var opts []oauth2.AuthCodeOption
|
||||||
// if pkce is enabled for client validate we have the cookie and set it as url param
|
// if pkce is enabled for client validate we have the cookie and set it as url param
|
||||||
if c.oauthCfg.UsePKCE {
|
if oauthCfg.UsePKCE {
|
||||||
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
|
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
|
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
|
||||||
@@ -115,15 +124,21 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
|
opts = append(opts, oauth2.VerifierOption(pkceCookie.Value))
|
||||||
}
|
}
|
||||||
|
|
||||||
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
connector, errConnector := c.socialService.GetConnector(c.providerName)
|
||||||
|
httpClient, errHTTPClient := c.socialService.GetOAuthHttpClient(c.providerName)
|
||||||
|
if errConnector != nil || errHTTPClient != nil {
|
||||||
|
return nil, errOAuthInternal.Errorf("failed to get %s oauth client: %w", c.name, errors.Join(errConnector, errHTTPClient))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
// exchange auth code to a valid token
|
// exchange auth code to a valid token
|
||||||
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
|
token, err := connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
|
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
|
||||||
}
|
}
|
||||||
token.TokenType = "Bearer"
|
token.TokenType = "Bearer"
|
||||||
|
|
||||||
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
|
userInfo, err := connector.UserInfo(ctx, connector.Client(clientCtx, token), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr *connectors.SocialError
|
var sErr *connectors.SocialError
|
||||||
if errors.As(err, &sErr) {
|
if errors.As(err, &sErr) {
|
||||||
@@ -136,7 +151,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
|
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.connector.IsEmailAllowed(userInfo.Email) {
|
if !connector.IsEmailAllowed(userInfo.Email) {
|
||||||
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
|
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +182,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
SyncTeams: true,
|
SyncTeams: true,
|
||||||
FetchSyncedUser: true,
|
FetchSyncedUser: true,
|
||||||
SyncPermissions: true,
|
SyncPermissions: true,
|
||||||
AllowSignUp: c.connector.IsSignupAllowed(),
|
AllowSignUp: connector.IsSignupAllowed(),
|
||||||
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
|
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
|
||||||
SyncOrgRoles: len(orgRoles) > 0,
|
SyncOrgRoles: len(orgRoles) > 0,
|
||||||
LookUpParams: lookupParams,
|
LookUpParams: lookupParams,
|
||||||
@@ -178,12 +193,17 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
|
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
|
||||||
var opts []oauth2.AuthCodeOption
|
var opts []oauth2.AuthCodeOption
|
||||||
|
|
||||||
if c.oauthCfg.HostedDomain != "" {
|
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
|
||||||
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
|
if !oauthCfg.Enabled {
|
||||||
|
return nil, errOAuthClientDisabled.Errorf("oauth client is disabled: %s", c.providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauthCfg.HostedDomain != "" {
|
||||||
|
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, oauthCfg.HostedDomain))
|
||||||
}
|
}
|
||||||
|
|
||||||
var plainPKCE string
|
var plainPKCE string
|
||||||
if c.oauthCfg.UsePKCE {
|
if oauthCfg.UsePKCE {
|
||||||
verifier, err := genPKCECodeVerifier()
|
verifier, err := genPKCECodeVerifier()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
|
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
|
||||||
@@ -193,13 +213,18 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
|
|||||||
opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
|
opts = append(opts, oauth2.S256ChallengeOption(plainPKCE))
|
||||||
}
|
}
|
||||||
|
|
||||||
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, oauthCfg.ClientSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
|
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connector, err := c.socialService.GetConnector(c.providerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errOAuthInternal.Errorf("failed to get %s oauth connector: %w", c.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
return &authn.Redirect{
|
return &authn.Redirect{
|
||||||
URL: c.connector.AuthCodeURL(state, opts...),
|
URL: connector.AuthCodeURL(state, opts...),
|
||||||
Extra: map[string]string{
|
Extra: map[string]string{
|
||||||
authn.KeyOAuthState: hashedSate,
|
authn.KeyOAuthState: hashedSate,
|
||||||
authn.KeyOAuthPKCE: plainPKCE,
|
authn.KeyOAuthPKCE: plainPKCE,
|
||||||
@@ -215,19 +240,25 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login
|
|||||||
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
|
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
|
oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName)
|
||||||
if redirctURL == "" {
|
if !oauthCfg.Enabled {
|
||||||
|
c.log.FromContext(ctx).Debug("OAuth client is disabled")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg)
|
||||||
|
if redirectURL == "" {
|
||||||
c.log.FromContext(ctx).Debug("No signout redirect url configured")
|
c.log.FromContext(ctx).Debug("No signout redirect url configured")
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
|
if isOICDLogout(redirectURL) && token != nil && token.Valid() {
|
||||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||||
redirctURL = withIDTokenHint(redirctURL, idToken)
|
redirectURL = withIDTokenHint(redirectURL, idToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &authn.Redirect{URL: redirctURL}, true
|
return &authn.Redirect{URL: redirectURL}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.
|
// genPKCECodeVerifier returns code verifier that 128 characters random URL-friendly string.
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
@@ -46,17 +47,23 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
{
|
{
|
||||||
desc: "should return error when missing state cookie",
|
desc: "should return error when missing state cookie",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
||||||
oauthCfg: &social.OAuthInfo{},
|
oauthCfg: &social.OAuthInfo{Enabled: true},
|
||||||
expectedErr: errOAuthMissingState,
|
expectedErr: errOAuthMissingState,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when state cookie is present but don't have a value",
|
desc: "should return error when state cookie is present but don't have a value",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
||||||
oauthCfg: &social.OAuthInfo{},
|
oauthCfg: &social.OAuthInfo{Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "",
|
stateCookieValue: "",
|
||||||
expectedErr: errOAuthMissingState,
|
expectedErr: errOAuthMissingState,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "should return error when the client is not enabled",
|
||||||
|
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
||||||
|
oauthCfg: &social.OAuthInfo{Enabled: false},
|
||||||
|
expectedErr: errOAuthClientDisabled,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "should return error when state from ipd does not match stored state",
|
desc: "should return error when state from ipd does not match stored state",
|
||||||
req: &authn.Request{HTTPRequest: &http.Request{
|
req: &authn.Request{HTTPRequest: &http.Request{
|
||||||
@@ -64,7 +71,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
expectedErr: errOAuthInvalidState,
|
expectedErr: errOAuthInvalidState,
|
||||||
@@ -76,7 +83,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
expectedErr: errOAuthMissingPKCE,
|
expectedErr: errOAuthMissingPKCE,
|
||||||
@@ -88,7 +95,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
addPKCECookie: true,
|
addPKCECookie: true,
|
||||||
@@ -103,7 +110,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
addPKCECookie: true,
|
addPKCECookie: true,
|
||||||
@@ -119,7 +126,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
addPKCECookie: true,
|
addPKCECookie: true,
|
||||||
@@ -157,7 +164,7 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
allowInsecureTakeover: true,
|
allowInsecureTakeover: true,
|
||||||
addStateCookie: true,
|
addStateCookie: true,
|
||||||
stateCookieValue: "some-state",
|
stateCookieValue: "some-state",
|
||||||
@@ -211,12 +218,18 @@ func TestOAuth_Authenticate(t *testing.T) {
|
|||||||
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
|
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
|
||||||
}
|
}
|
||||||
|
|
||||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
|
fakeSocialSvc := &socialtest.FakeSocialService{
|
||||||
ExpectedUserInfo: tt.userInfo,
|
ExpectedAuthInfoProvider: tt.oauthCfg,
|
||||||
ExpectedToken: &oauth2.Token{},
|
ExpectedConnector: fakeConnector{
|
||||||
ExpectedIsSignupAllowed: true,
|
ExpectedUserInfo: tt.userInfo,
|
||||||
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
ExpectedToken: &oauth2.Token{},
|
||||||
}, nil, nil)
|
ExpectedIsSignupAllowed: true,
|
||||||
|
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc)
|
||||||
|
|
||||||
identity, err := c.Authenticate(context.Background(), tt.req)
|
identity, err := c.Authenticate(context.Background(), tt.req)
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
|
||||||
@@ -256,21 +269,27 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
|||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
desc: "should generate redirect url and state",
|
desc: "should generate redirect url and state",
|
||||||
oauthCfg: &social.OAuthInfo{},
|
oauthCfg: &social.OAuthInfo{Enabled: true},
|
||||||
authCodeUrlCalled: true,
|
authCodeUrlCalled: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should generate redirect url with hosted domain option if configured",
|
desc: "should generate redirect url with hosted domain option if configured",
|
||||||
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"},
|
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true},
|
||||||
numCallOptions: 1,
|
numCallOptions: 1,
|
||||||
authCodeUrlCalled: true,
|
authCodeUrlCalled: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "should generate redirect url with pkce if configured",
|
desc: "should generate redirect url with pkce if configured",
|
||||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true},
|
||||||
numCallOptions: 1,
|
numCallOptions: 1,
|
||||||
authCodeUrlCalled: true,
|
authCodeUrlCalled: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "should return error if the client is not enabled",
|
||||||
|
oauthCfg: &social.OAuthInfo{Enabled: false},
|
||||||
|
authCodeUrlCalled: false,
|
||||||
|
expectedErr: errOAuthClientDisabled,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -279,13 +298,18 @@ func TestOAuth_RedirectURL(t *testing.T) {
|
|||||||
authCodeUrlCalled = false
|
authCodeUrlCalled = false
|
||||||
)
|
)
|
||||||
|
|
||||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{
|
fakeSocialSvc := &socialtest.FakeSocialService{
|
||||||
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
|
ExpectedAuthInfoProvider: tt.oauthCfg,
|
||||||
authCodeUrlCalled = true
|
ExpectedConnector: mockConnector{
|
||||||
require.Len(t, opts, tt.numCallOptions)
|
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
|
||||||
return ""
|
authCodeUrlCalled = true
|
||||||
|
require.Len(t, opts, tt.numCallOptions)
|
||||||
|
return ""
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}, nil, nil)
|
}
|
||||||
|
|
||||||
|
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), nil, fakeSocialSvc)
|
||||||
|
|
||||||
redirect, err := c.RedirectURL(context.Background(), nil)
|
redirect, err := c.RedirectURL(context.Background(), nil)
|
||||||
assert.ErrorIs(t, err, tt.expectedErr)
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
@@ -321,12 +345,17 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
cfg: &setting.Cfg{},
|
cfg: &setting.Cfg{},
|
||||||
oauthCfg: &social.OAuthInfo{},
|
oauthCfg: &social.OAuthInfo{},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "should not return redirect url when client is not enabled",
|
||||||
|
cfg: &setting.Cfg{},
|
||||||
|
oauthCfg: &social.OAuthInfo{Enabled: false},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "should return redirect url for globably configured redirect url",
|
desc: "should return redirect url for globably configured redirect url",
|
||||||
cfg: &setting.Cfg{
|
cfg: &setting.Cfg{
|
||||||
SignoutRedirectUrl: "http://idp.com/logout",
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{},
|
oauthCfg: &social.OAuthInfo{Enabled: true},
|
||||||
expectedURL: "http://idp.com/logout",
|
expectedURL: "http://idp.com/logout",
|
||||||
expectedOK: true,
|
expectedOK: true,
|
||||||
},
|
},
|
||||||
@@ -334,6 +363,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
desc: "should return redirect url for client configured redirect url",
|
desc: "should return redirect url for client configured redirect url",
|
||||||
cfg: &setting.Cfg{},
|
cfg: &setting.Cfg{},
|
||||||
oauthCfg: &social.OAuthInfo{
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
Enabled: true,
|
||||||
SignoutRedirectUrl: "http://idp.com/logout",
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
},
|
},
|
||||||
expectedURL: "http://idp.com/logout",
|
expectedURL: "http://idp.com/logout",
|
||||||
@@ -345,6 +375,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
SignoutRedirectUrl: "http://idp.com/logout",
|
SignoutRedirectUrl: "http://idp.com/logout",
|
||||||
},
|
},
|
||||||
oauthCfg: &social.OAuthInfo{
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
Enabled: true,
|
||||||
SignoutRedirectUrl: "http://idp-2.com/logout",
|
SignoutRedirectUrl: "http://idp-2.com/logout",
|
||||||
},
|
},
|
||||||
expectedURL: "http://idp-2.com/logout",
|
expectedURL: "http://idp-2.com/logout",
|
||||||
@@ -354,6 +385,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
desc: "should add id token hint if oicd logout is configured and token is valid",
|
desc: "should add id token hint if oicd logout is configured and token is valid",
|
||||||
cfg: &setting.Cfg{},
|
cfg: &setting.Cfg{},
|
||||||
oauthCfg: &social.OAuthInfo{
|
oauthCfg: &social.OAuthInfo{
|
||||||
|
Enabled: true,
|
||||||
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
|
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
|
||||||
},
|
},
|
||||||
expectedURL: "http://idp.com/logout",
|
expectedURL: "http://idp.com/logout",
|
||||||
@@ -387,7 +419,10 @@ func TestOAuth_Logout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
|
fakeSocialSvc := &socialtest.FakeSocialService{
|
||||||
|
ExpectedAuthInfoProvider: tt.oauthCfg,
|
||||||
|
}
|
||||||
|
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc)
|
||||||
|
|
||||||
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})
|
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})
|
||||||
|
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
|
|||||||
|
|
||||||
settings.Provider = key
|
settings.Provider = key
|
||||||
|
|
||||||
err := api.SSOSettingsService.Upsert(c.Req.Context(), settings)
|
err := api.SSOSettingsService.Upsert(c.Req.Context(), &settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err)
|
return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func TestSSOSettingsAPI_Update(t *testing.T) {
|
|||||||
|
|
||||||
service := ssosettingstests.NewMockService(t)
|
service := ssosettingstests.NewMockService(t)
|
||||||
if tt.expectedServiceCall {
|
if tt.expectedServiceCall {
|
||||||
service.On("Upsert", mock.Anything, settings).Return(tt.expectedError).Once()
|
service.On("Upsert", mock.Anything, &settings).Return(tt.expectedError).Once()
|
||||||
}
|
}
|
||||||
server := setupTests(t, service)
|
server := setupTests(t, service)
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
if settings.Provider == "" {
|
if settings.Provider == "" {
|
||||||
return ssosettings.ErrNotFound
|
return ssosettings.ErrNotFound
|
||||||
}
|
}
|
||||||
@@ -110,13 +110,10 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin
|
|||||||
}
|
}
|
||||||
_, err = sess.UseBool(isDeletedColumn).Update(updated, existing)
|
_, err = sess.UseBool(isDeletedColumn).Update(updated, existing)
|
||||||
} else {
|
} else {
|
||||||
_, err = sess.Insert(&models.SSOSettings{
|
settings.ID = uuid.New().String()
|
||||||
ID: uuid.New().String(),
|
settings.Created = now
|
||||||
Provider: settings.Provider,
|
settings.Updated = now
|
||||||
Settings: settings.Settings,
|
_, err = sess.Insert(settings)
|
||||||
Created: now,
|
|
||||||
Updated: now,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ssoSettingsStore.Upsert(context.Background(), settings)
|
err := ssoSettingsStore.Upsert(context.Background(), &settings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
|
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
|
||||||
@@ -143,7 +143,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
|||||||
"client_secret": "this-is-a-new-secret",
|
"client_secret": "this-is-a-new-secret",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
||||||
@@ -181,7 +181,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
||||||
@@ -217,7 +217,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
|||||||
"client_secret": "this-is-my-new-secret",
|
"client_secret": "this-is-my-new-secret",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
err = ssoSettingsStore.Upsert(context.Background(), &newSettings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
|
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
|
||||||
@@ -254,7 +254,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ssoSettingsStore.Upsert(context.Background(), settings)
|
err = ssoSettingsStore.Upsert(context.Background(), &settings)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, ssosettings.ErrNotFound)
|
require.ErrorIs(t, err, ssosettings.ErrNotFound)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrNotFound = errors.New("not found")
|
ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("provider is invalid"))
|
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("Provider is invalid"))
|
||||||
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("settings field is invalid"))
|
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("Settings field is invalid"))
|
||||||
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("settings.clientId cannot be empty"))
|
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("ClientId cannot be empty"))
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ type Service interface {
|
|||||||
// GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted
|
// GetForProviderWithRedactedSecrets returns the SSO settings for a given provider (DB or config file) with secret values redacted
|
||||||
GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error)
|
GetForProviderWithRedactedSecrets(ctx context.Context, provider string) (*models.SSOSettings, error)
|
||||||
// Upsert creates or updates the SSO settings for a given provider
|
// Upsert creates or updates the SSO settings for a given provider
|
||||||
Upsert(ctx context.Context, settings models.SSOSettings) error
|
Upsert(ctx context.Context, settings *models.SSOSettings) error
|
||||||
// Delete deletes the SSO settings for a given provider (soft delete)
|
// Delete deletes the SSO settings for a given provider (soft delete)
|
||||||
Delete(ctx context.Context, provider string) error
|
Delete(ctx context.Context, provider string) error
|
||||||
// Patch updates the specified SSO settings (key-value pairs) for a given provider
|
// Patch updates the specified SSO settings (key-value pairs) for a given provider
|
||||||
@@ -62,6 +62,6 @@ type FallbackStrategy interface {
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
Get(ctx context.Context, provider string) (*models.SSOSettings, error)
|
Get(ctx context.Context, provider string) (*models.SSOSettings, error)
|
||||||
List(ctx context.Context) ([]*models.SSOSettings, error)
|
List(ctx context.Context) ([]*models.SSOSettings, error)
|
||||||
Upsert(ctx context.Context, settings models.SSOSettings) error
|
Upsert(ctx context.Context, settings *models.SSOSettings) error
|
||||||
Delete(ctx context.Context, provider string) error
|
Delete(ctx context.Context, provider string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ func (s *SSOSettingsService) ListWithRedactedSecrets(ctx context.Context) ([]*mo
|
|||||||
return storeSettings, nil
|
return storeSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (s *SSOSettingsService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
if !isProviderConfigurable(settings.Provider) {
|
if !isProviderConfigurable(settings.Provider) {
|
||||||
return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider)
|
return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider)
|
||||||
}
|
}
|
||||||
@@ -157,7 +157,7 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
|
|||||||
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider)
|
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := social.Validate(ctx, settings)
|
err := social.Validate(ctx, *settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -167,6 +167,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
secrets := collectSecrets(settings, storedSettings)
|
||||||
|
|
||||||
settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings)
|
settings.Settings, err = s.encryptSecrets(ctx, settings.Settings, storedSettings.Settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -178,7 +180,8 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = social.Reload(context.Background(), settings)
|
settings.Settings = overrideMaps(storedSettings.Settings, settings.Settings, secrets)
|
||||||
|
err = social.Reload(context.Background(), *settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
|
s.logger.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
|
||||||
}
|
}
|
||||||
@@ -249,7 +252,7 @@ func (s *SSOSettingsService) getFallbackStrategyFor(provider string) (ssosetting
|
|||||||
func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) {
|
func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings, storedSettings map[string]any) (map[string]any, error) {
|
||||||
result := make(map[string]any)
|
result := make(map[string]any)
|
||||||
for k, v := range settings {
|
for k, v := range settings {
|
||||||
if isSecret(k) {
|
if isSecret(k) && v != "" {
|
||||||
strValue, ok := v.(string)
|
strValue, ok := v.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
||||||
@@ -315,20 +318,24 @@ func (s *SSOSettingsService) doReload(ctx context.Context) {
|
|||||||
|
|
||||||
// mergeSSOSettings merges the settings from the database with the system settings
|
// mergeSSOSettings merges the settings from the database with the system settings
|
||||||
// Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings)
|
// Required because it is possible that the user has configured some of the settings (current Advanced OAuth settings)
|
||||||
// and the rest of the settings are loaded from the system settings
|
// and the rest of the settings have to be loaded from the system settings
|
||||||
func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings {
|
func (s *SSOSettingsService) mergeSSOSettings(dbSettings, systemSettings *models.SSOSettings) *models.SSOSettings {
|
||||||
if dbSettings == nil {
|
if dbSettings == nil {
|
||||||
s.logger.Debug("No SSO Settings found in the database, using system settings")
|
s.logger.Debug("No SSO Settings found in the database, using system settings")
|
||||||
return systemSettings
|
return systemSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
s.logger.Debug("Merging SSO Settings", "dbSettings", dbSettings.Settings, "systemSettings", systemSettings.Settings)
|
s.logger.Debug("Merging SSO Settings", "dbSettings", removeSecrets(dbSettings.Settings), "systemSettings", removeSecrets(systemSettings.Settings))
|
||||||
|
|
||||||
finalSettings := mergeSettings(dbSettings.Settings, systemSettings.Settings)
|
result := &models.SSOSettings{
|
||||||
|
Provider: dbSettings.Provider,
|
||||||
|
Source: dbSettings.Source,
|
||||||
|
Settings: mergeSettings(dbSettings.Settings, systemSettings.Settings),
|
||||||
|
Created: dbSettings.Created,
|
||||||
|
Updated: dbSettings.Updated,
|
||||||
|
}
|
||||||
|
|
||||||
dbSettings.Settings = finalSettings
|
return result
|
||||||
|
|
||||||
return dbSettings
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
|
func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
|
||||||
@@ -358,6 +365,22 @@ func (s *SSOSettingsService) decryptSecrets(ctx context.Context, settings map[st
|
|||||||
return settings, nil
|
return settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeSecrets removes all the secrets from the map and replaces them with a redacted password
|
||||||
|
// and returns a new map
|
||||||
|
func removeSecrets(settings map[string]any) map[string]any {
|
||||||
|
result := make(map[string]any)
|
||||||
|
for k, v := range settings {
|
||||||
|
if isSecret(k) {
|
||||||
|
result[k] = setting.RedactedPassword
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeSettings merges two maps in a way that the values from the first map are preserved
|
||||||
|
// and the values from the second map are added only if they don't exist in the first map
|
||||||
func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any {
|
func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any {
|
||||||
settings := make(map[string]any)
|
settings := make(map[string]any)
|
||||||
|
|
||||||
@@ -374,6 +397,32 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
|
|||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collectSecrets collects all the secrets from the request and the currently stored settings
|
||||||
|
// and returns a new map
|
||||||
|
func collectSecrets(settings *models.SSOSettings, storedSettings *models.SSOSettings) map[string]any {
|
||||||
|
secrets := map[string]any{}
|
||||||
|
for k, v := range settings.Settings {
|
||||||
|
if isSecret(k) {
|
||||||
|
if isNewSecretValue(v.(string)) {
|
||||||
|
secrets[k] = v.(string) // use the new value
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
secrets[k] = storedSettings.Settings[k] // keep the currently stored value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return secrets
|
||||||
|
}
|
||||||
|
|
||||||
|
func overrideMaps(maps ...map[string]any) map[string]any {
|
||||||
|
result := make(map[string]any)
|
||||||
|
for _, m := range maps {
|
||||||
|
for k, v := range m {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func isSecret(fieldName string) bool {
|
func isSecret(fieldName string) bool {
|
||||||
secretFieldPatterns := []string{"secret"}
|
secretFieldPatterns := []string{"secret"}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -772,16 +775,50 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
},
|
},
|
||||||
IsDeleted: false,
|
IsDeleted: false,
|
||||||
}
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||||
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe()
|
reloadable.On("Reload", mock.Anything, mock.MatchedBy(func(settings models.SSOSettings) bool {
|
||||||
|
wg.Done()
|
||||||
|
return settings.Provider == provider &&
|
||||||
|
settings.ID == "someid" &&
|
||||||
|
maps.Equal(settings.Settings, map[string]any{
|
||||||
|
"client_id": "client-id",
|
||||||
|
"client_secret": "client-secret",
|
||||||
|
"enabled": true,
|
||||||
|
})
|
||||||
|
})).Return(nil).Once()
|
||||||
env.reloadables[provider] = reloadable
|
env.reloadables[provider] = reloadable
|
||||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||||
|
env.secrets.On("Decrypt", mock.Anything, []byte("encrypted-current-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
|
currentTime := time.Now()
|
||||||
|
settings.ID = "someid"
|
||||||
|
settings.Created = currentTime
|
||||||
|
settings.Updated = currentTime
|
||||||
|
|
||||||
|
env.store.ActualSSOSettings = *settings
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||||
|
return &models.SSOSettings{
|
||||||
|
ID: "someid",
|
||||||
|
Provider: provider,
|
||||||
|
Settings: map[string]any{
|
||||||
|
"client_secret": base64.RawStdEncoding.EncodeToString([]byte("encrypted-current-client-secret")),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for the goroutine first to assert the Reload call
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
|
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
|
||||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||||
})
|
})
|
||||||
@@ -790,7 +827,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
env := setupTestEnv(t)
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
provider := social.GrafanaComProviderName
|
provider := social.GrafanaComProviderName
|
||||||
settings := models.SSOSettings{
|
settings := &models.SSOSettings{
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
Settings: map[string]any{
|
Settings: map[string]any{
|
||||||
"client_id": "client-id",
|
"client_id": "client-id",
|
||||||
@@ -811,7 +848,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
env := setupTestEnv(t)
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
provider := social.AzureADProviderName
|
provider := social.AzureADProviderName
|
||||||
settings := models.SSOSettings{
|
settings := &models.SSOSettings{
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
Settings: map[string]any{
|
Settings: map[string]any{
|
||||||
"client_id": "client-id",
|
"client_id": "client-id",
|
||||||
@@ -847,14 +884,14 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed"))
|
reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed"))
|
||||||
env.reloadables[provider] = reloadable
|
env.reloadables[provider] = reloadable
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
|
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
|
||||||
env := setupTestEnv(t)
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
settings := models.SSOSettings{
|
settings := &models.SSOSettings{
|
||||||
Provider: social.AzureADProviderName,
|
Provider: social.AzureADProviderName,
|
||||||
Settings: map[string]any{
|
Settings: map[string]any{
|
||||||
"client_id": "client-id",
|
"client_id": "client-id",
|
||||||
@@ -889,7 +926,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
env.reloadables[provider] = reloadable
|
env.reloadables[provider] = reloadable
|
||||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -921,7 +958,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||||
env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
|
env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once()
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret"))
|
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret"))
|
||||||
@@ -950,11 +987,11 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
return &models.SSOSettings{}, nil
|
return &models.SSOSettings{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
env.store.UpsertFn = func(ctx context.Context, settings models.SSOSettings) error {
|
env.store.UpsertFn = func(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
return errors.New("failed to upsert settings")
|
return errors.New("failed to upsert settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -978,7 +1015,7 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
|||||||
env.reloadables[provider] = reloadable
|
env.reloadables[provider] = reloadable
|
||||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||||
|
|
||||||
err := env.service.Upsert(context.Background(), settings)
|
err := env.service.Upsert(context.Background(), &settings)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
|
settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret"))
|
||||||
@@ -1098,6 +1135,22 @@ func TestSSOSettingsService_decryptSecrets(t *testing.T) {
|
|||||||
"other_secret": "decrypted-other-secret",
|
"other_secret": "decrypted-other-secret",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "should not decrypt when a secret is empty",
|
||||||
|
setup: func(env testEnv) {
|
||||||
|
env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once()
|
||||||
|
},
|
||||||
|
settings: map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"client_secret": "",
|
||||||
|
"other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")),
|
||||||
|
},
|
||||||
|
want: map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"client_secret": "",
|
||||||
|
"other_secret": "decrypted-other-secret",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "should return an error if data is not a string",
|
name: "should return an error if data is not a string",
|
||||||
settings: map[string]any{
|
settings: map[string]any{
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by mockery v2.38.0. DO NOT EDIT.
|
// Code generated by mockery v2.40.1. DO NOT EDIT.
|
||||||
|
|
||||||
package ssosettingstests
|
package ssosettingstests
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by mockery v2.38.0. DO NOT EDIT.
|
// Code generated by mockery v2.40.1. DO NOT EDIT.
|
||||||
|
|
||||||
package ssosettingstests
|
package ssosettingstests
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ func (_m *MockService) Reload(ctx context.Context, provider string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upsert provides a mock function with given fields: ctx, settings
|
// Upsert provides a mock function with given fields: ctx, settings
|
||||||
func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (_m *MockService) Upsert(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
ret := _m.Called(ctx, settings)
|
ret := _m.Called(ctx, settings)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
@@ -191,7 +191,7 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings)
|
|||||||
}
|
}
|
||||||
|
|
||||||
var r0 error
|
var r0 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
|
||||||
r0 = rf(ctx, settings)
|
r0 = rf(ctx, settings)
|
||||||
} else {
|
} else {
|
||||||
r0 = ret.Error(0)
|
r0 = ret.Error(0)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type FakeStore struct {
|
|||||||
ActualSSOSettings models.SSOSettings
|
ActualSSOSettings models.SSOSettings
|
||||||
|
|
||||||
GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error)
|
GetFn func(ctx context.Context, provider string) (*models.SSOSettings, error)
|
||||||
UpsertFn func(ctx context.Context, settings models.SSOSettings) error
|
UpsertFn func(ctx context.Context, settings *models.SSOSettings) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFakeStore() *FakeStore {
|
func NewFakeStore() *FakeStore {
|
||||||
@@ -35,12 +35,12 @@ func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
|||||||
return f.ExpectedSSOSettings, f.ExpectedError
|
return f.ExpectedSSOSettings, f.ExpectedError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (f *FakeStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
if f.UpsertFn != nil {
|
if f.UpsertFn != nil {
|
||||||
return f.UpsertFn(ctx, settings)
|
return f.UpsertFn(ctx, settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.ActualSSOSettings = settings
|
f.ActualSSOSettings = *settings
|
||||||
|
|
||||||
return f.ExpectedError
|
return f.ExpectedError
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by mockery v2.39.2. DO NOT EDIT.
|
// Code generated by mockery v2.40.1. DO NOT EDIT.
|
||||||
|
|
||||||
package ssosettingstests
|
package ssosettingstests
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upsert provides a mock function with given fields: ctx, settings
|
// Upsert provides a mock function with given fields: ctx, settings
|
||||||
func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (_m *MockStore) Upsert(ctx context.Context, settings *models.SSOSettings) error {
|
||||||
ret := _m.Called(ctx, settings)
|
ret := _m.Called(ctx, settings)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
@@ -101,7 +101,7 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
var r0 error
|
var r0 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *models.SSOSettings) error); ok {
|
||||||
r0 = rf(ctx, settings)
|
r0 = rf(ctx, settings)
|
||||||
} else {
|
} else {
|
||||||
r0 = ret.Error(0)
|
r0 = ret.Error(0)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package strategies
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"maps"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||||
@@ -31,7 +32,10 @@ func (s *OAuthStrategy) IsMatch(provider string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) {
|
func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) {
|
||||||
return s.settingsByProvider[provider], nil
|
providerConfig := s.settingsByProvider[provider]
|
||||||
|
result := make(map[string]any, len(providerConfig))
|
||||||
|
maps.Copy(result, providerConfig)
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthStrategy) loadAllSettings() {
|
func (s *OAuthStrategy) loadAllSettings() {
|
||||||
|
|||||||
Reference in New Issue
Block a user