diff --git a/pkg/login/social/connectors/azuread_oauth.go b/pkg/login/social/connectors/azuread_oauth.go index 58e347f4bf7..86b22c042da 100644 --- a/pkg/login/social/connectors/azuread_oauth.go +++ b/pkg/login/social/connectors/azuread_oauth.go @@ -113,10 +113,12 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token return nil, ErrEmailNotFound } + info := s.GetOAuthInfo() + // setting the role, grafanaAdmin to empty to reflect that we are not syncronizing with the external provider var role roletype.RoleType var grafanaAdmin bool - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { role, grafanaAdmin, err = s.extractRoleAndAdmin(claims) if err != nil { return nil, err @@ -143,11 +145,11 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token } var isGrafanaAdmin *bool = nil - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { isGrafanaAdmin = &grafanaAdmin } - if s.info.AllowAssignGrafanaAdmin && s.info.SkipOrgRoleSync { + if info.AllowAssignGrafanaAdmin && info.SkipOrgRoleSync { s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } @@ -178,14 +180,6 @@ func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSett return nil } -func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - -func (s *SocialAzureAD) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - func (s *SocialAzureAD) validateClaims(ctx context.Context, client *http.Client, parsedToken *jwt.JSONWebToken) (*azureClaims, error) { claims, err := s.validateIDTokenSignature(ctx, client, parsedToken) if err != nil { @@ -257,8 +251,10 @@ func (claims *azureClaims) extractEmail() string { // extractRoleAndAdmin extracts the role from the claims and returns the role and whether the user is a Grafana admin. func (s *SocialAzureAD) extractRoleAndAdmin(claims *azureClaims) (org.RoleType, bool, error) { + info := s.GetOAuthInfo() + if len(claims.Roles) == 0 { - if s.info.RoleAttributeStrict { + if info.RoleAttributeStrict { return "", false, errRoleAttributeStrictViolation.Errorf("AzureAD OAuth: unset role") } return s.defaultRole(), false, nil @@ -276,7 +272,7 @@ func (s *SocialAzureAD) extractRoleAndAdmin(claims *azureClaims) (org.RoleType, } } - if s.info.RoleAttributeStrict { + if info.RoleAttributeStrict { return "", false, errRoleAttributeStrictViolation.Errorf("AzureAD OAuth: idP did not return a valid role %q", claims.Roles) } @@ -400,9 +396,11 @@ func (s *SocialAzureAD) groupsGraphAPIURL(claims *azureClaims, token *oauth2.Tok } func (s *SocialAzureAD) SupportBundleContent(bf *bytes.Buffer) error { + info := s.GetOAuthInfo() + bf.WriteString("## AzureAD specific configuration\n\n") bf.WriteString("```ini\n") - bf.WriteString(fmt.Sprintf("allowed_groups = %v\n", s.info.AllowedGroups)) + bf.WriteString(fmt.Sprintf("allowed_groups = %v\n", info.AllowedGroups)) bf.WriteString(fmt.Sprintf("forceUseGraphAPI = %v\n", s.forceUseGraphAPI)) bf.WriteString("```\n\n") diff --git a/pkg/login/social/connectors/azuread_oauth_test.go b/pkg/login/social/connectors/azuread_oauth_test.go index 2e4a9153731..18191e169c0 100644 --- a/pkg/login/social/connectors/azuread_oauth_test.go +++ b/pkg/login/social/connectors/azuread_oauth_test.go @@ -1045,3 +1045,67 @@ func TestSocialAzureAD_Validate(t *testing.T) { }) } } + +func TestSocialAzureAD_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewAzureADProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), nil) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/common.go b/pkg/login/social/connectors/common.go index d9700dc6860..93517fed981 100644 --- a/pkg/login/social/connectors/common.go +++ b/pkg/login/social/connectors/common.go @@ -38,11 +38,15 @@ type httpGetResponse struct { } func (s *SocialBase) IsEmailAllowed(email string) bool { - return isEmailAllowed(email, s.info.AllowedDomains) + info := s.GetOAuthInfo() + + return isEmailAllowed(email, info.AllowedDomains) } func (s *SocialBase) IsSignupAllowed() bool { - return s.info.AllowSignup + info := s.GetOAuthInfo() + + return info.AllowSignup } func isEmailAllowed(email string, allowedDomains []string) bool { diff --git a/pkg/login/social/connectors/generic_oauth.go b/pkg/login/social/connectors/generic_oauth.go index 6fd98b60e9f..fcc8f33ca56 100644 --- a/pkg/login/social/connectors/generic_oauth.go +++ b/pkg/login/social/connectors/generic_oauth.go @@ -84,17 +84,15 @@ func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SS return nil } -func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - // TODOD: remove this in the next PR and use the isGroupMember from social.go func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool { - if len(s.info.AllowedGroups) == 0 { + info := s.GetOAuthInfo() + + if len(info.AllowedGroups) == 0 { return true } - for _, allowedGroup := range s.info.AllowedGroups { + for _, allowedGroup := range info.AllowedGroups { for _, group := range groups { if group == allowedGroup { return true @@ -177,6 +175,8 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, toCheck = append(toCheck, apiData) } + info := s.GetOAuthInfo() + userInfo := &social.BasicUserInfo{} for _, data := range toCheck { s.log.Debug("Processing external user info", "source", data.source, "data", data) @@ -200,13 +200,13 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, } } - if userInfo.Role == "" && !s.info.SkipOrgRoleSync { + if userInfo.Role == "" && !info.SkipOrgRoleSync { role, grafanaAdmin, err := s.extractRoleAndAdminOptional(data.rawJSON, []string{}) if err != nil { s.log.Warn("Failed to extract role", "err", err) } else { userInfo.Role = role - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { userInfo.IsGrafanaAdmin = &grafanaAdmin } } @@ -223,14 +223,14 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, } } - if userInfo.Role == "" && !s.info.SkipOrgRoleSync { - if s.info.RoleAttributeStrict { + if userInfo.Role == "" && !info.SkipOrgRoleSync { + if info.RoleAttributeStrict { return nil, errRoleAttributeStrictViolation.Errorf("idP did not return a role attribute") } userInfo.Role = s.defaultRole() } - if s.info.AllowAssignGrafanaAdmin && s.info.SkipOrgRoleSync { + if info.AllowAssignGrafanaAdmin && info.SkipOrgRoleSync { s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } @@ -264,10 +264,6 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, return userInfo, nil } -func (s *SocialGenericOAuth) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson { s.log.Debug("Extracting user info from OAuth token") @@ -302,15 +298,17 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson } func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson { + info := s.GetOAuthInfo() + s.log.Debug("Getting user info from API") - if s.info.ApiUrl == "" { + if info.ApiUrl == "" { s.log.Debug("No api url configured") return nil } - rawUserInfoResponse, err := s.httpGet(ctx, client, s.info.ApiUrl) + rawUserInfoResponse, err := s.httpGet(ctx, client, info.ApiUrl) if err != nil { - s.log.Debug("Error getting user info from API", "url", s.info.ApiUrl, "error", err) + s.log.Debug("Error getting user info from API", "url", info.ApiUrl, "error", err) return nil } @@ -426,9 +424,11 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http IsConfirmed bool `json:"is_confirmed"` } - response, err := s.httpGet(ctx, client, fmt.Sprintf(s.info.ApiUrl+"/emails")) + info := s.GetOAuthInfo() + + response, err := s.httpGet(ctx, client, fmt.Sprintf(info.ApiUrl+"/emails")) if err != nil { - s.log.Error("Error getting email address", "url", s.info.ApiUrl+"/emails", "error", err) + s.log.Error("Error getting email address", "url", info.ApiUrl+"/emails", "error", err) return "", fmt.Errorf("%v: %w", "Error getting email address", err) } @@ -488,9 +488,11 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx cont Id int `json:"id"` } - response, err := s.httpGet(ctx, client, fmt.Sprintf(s.info.ApiUrl+"/teams")) + info := s.GetOAuthInfo() + + response, err := s.httpGet(ctx, client, fmt.Sprintf(info.ApiUrl+"/teams")) if err != nil { - s.log.Error("Error getting team memberships", "url", s.info.ApiUrl+"/teams", "error", err) + s.log.Error("Error getting team memberships", "url", info.ApiUrl+"/teams", "error", err) return []string{}, err } @@ -529,9 +531,11 @@ func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *htt Login string `json:"login"` } - response, err := s.httpGet(ctx, client, fmt.Sprintf(s.info.ApiUrl+"/orgs")) + info := s.GetOAuthInfo() + + response, err := s.httpGet(ctx, client, fmt.Sprintf(info.ApiUrl+"/orgs")) if err != nil { - s.log.Error("Error getting organizations", "url", s.info.ApiUrl+"/orgs", "error", err) + s.log.Error("Error getting organizations", "url", info.ApiUrl+"/orgs", "error", err) return nil, false } diff --git a/pkg/login/social/connectors/generic_oauth_test.go b/pkg/login/social/connectors/generic_oauth_test.go index 655f88875cc..d19bcabd63f 100644 --- a/pkg/login/social/connectors/generic_oauth_test.go +++ b/pkg/login/social/connectors/generic_oauth_test.go @@ -973,3 +973,67 @@ func TestSocialGenericOAuth_Validate(t *testing.T) { }) } } + +func TestSocialGenericOAuth_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGenericOAuthProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/github_oauth.go b/pkg/login/social/connectors/github_oauth.go index 5863b453d45..749596df075 100644 --- a/pkg/login/social/connectors/github_oauth.go +++ b/pkg/login/social/connectors/github_oauth.go @@ -91,10 +91,6 @@ func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSetti return nil } -func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool { if len(s.teamIds) == 0 { return true @@ -145,7 +141,9 @@ func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Clien Verified bool `json:"verified"` } - response, err := s.httpGet(ctx, client, fmt.Sprintf(s.info.ApiUrl+"/emails")) + info := s.GetOAuthInfo() + + response, err := s.httpGet(ctx, client, fmt.Sprintf(info.ApiUrl+"/emails")) if err != nil { return "", fmt.Errorf("Error getting email address: %s", err) } @@ -168,7 +166,9 @@ func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Clien } func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) { - url := fmt.Sprintf(s.info.ApiUrl + "/teams?per_page=100") + info := s.GetOAuthInfo() + + url := fmt.Sprintf(info.ApiUrl + "/teams?per_page=100") hasMore := true teams := make([]GithubTeam, 0) @@ -250,7 +250,9 @@ func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token Name string `json:"name"` } - response, err := s.httpGet(ctx, client, s.info.ApiUrl) + info := s.GetOAuthInfo() + + response, err := s.httpGet(ctx, client, info.ApiUrl) if err != nil { return nil, fmt.Errorf("error getting user info: %s", err) } @@ -269,20 +271,20 @@ func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token var role roletype.RoleType var isGrafanaAdmin *bool = nil - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { var grafanaAdmin bool role, grafanaAdmin, err = s.extractRoleAndAdmin(response.Body, teams) if err != nil { return nil, err } - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { isGrafanaAdmin = &grafanaAdmin } } // we skip allowing assignment of GrafanaAdmin if skipOrgRoleSync is present - if s.info.AllowAssignGrafanaAdmin && s.info.SkipOrgRoleSync { + if info.AllowAssignGrafanaAdmin && info.SkipOrgRoleSync { s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } @@ -299,7 +301,7 @@ func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token userInfo.Name = data.Name } - organizationsUrl := fmt.Sprintf(s.info.ApiUrl + "/orgs?per_page=100") + organizationsUrl := fmt.Sprintf(info.ApiUrl + "/orgs?per_page=100") if !s.IsTeamMember(ctx, client) { return nil, ErrMissingTeamMembership.Errorf("User is not a member of any of the allowed teams: %v", s.teamIds) @@ -328,10 +330,6 @@ func (t *GithubTeam) GetShorthand() (string, error) { return fmt.Sprintf("@%s/%s", t.Organization.Login, t.Slug), nil } -func (s *SocialGithub) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - func convertToGroupList(t []GithubTeam) []string { groups := make([]string, 0) for _, team := range t { diff --git a/pkg/login/social/connectors/github_oauth_test.go b/pkg/login/social/connectors/github_oauth_test.go index bbc1c23f8da..9b6613a0e0a 100644 --- a/pkg/login/social/connectors/github_oauth_test.go +++ b/pkg/login/social/connectors/github_oauth_test.go @@ -399,3 +399,67 @@ func TestSocialGitHub_Validate(t *testing.T) { }) } } + +func TestSocialGitHub_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGitHubProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/gitlab_oauth.go b/pkg/login/social/connectors/gitlab_oauth.go index 385882a8a6f..fbdafd557dc 100644 --- a/pkg/login/social/connectors/gitlab_oauth.go +++ b/pkg/login/social/connectors/gitlab_oauth.go @@ -81,10 +81,6 @@ func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSetti return nil } -func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - func (s *SocialGitlab) getGroups(ctx context.Context, client *http.Client) []string { groups := make([]string, 0) nextPage := new(int) @@ -104,7 +100,9 @@ func (s *SocialGitlab) getGroupsPage(ctx context.Context, client *http.Client, n FullPath string `json:"full_path"` } - groupURL, err := url.JoinPath(s.info.ApiUrl, "/groups") + info := s.GetOAuthInfo() + + groupURL, err := url.JoinPath(info.ApiUrl, "/groups") if err != nil { s.log.Error("Error joining GitLab API URL", "err", err) return nil, nil @@ -165,6 +163,8 @@ func (s *SocialGitlab) getGroupsPage(ctx context.Context, client *http.Client, n } func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { + info := s.GetOAuthInfo() + data, err := s.extractFromToken(ctx, client, token) if err != nil { return nil, err @@ -193,20 +193,18 @@ func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token return nil, errMissingGroupMembership } - if s.info.AllowAssignGrafanaAdmin && s.info.SkipOrgRoleSync { + if info.AllowAssignGrafanaAdmin && info.SkipOrgRoleSync { s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } return userInfo, nil } -func (s *SocialGitlab) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - func (s *SocialGitlab) extractFromAPI(ctx context.Context, client *http.Client, token *oauth2.Token) (*userData, error) { + info := s.GetOAuthInfo() + apiResp := &apiData{} - response, err := s.httpGet(ctx, client, s.info.ApiUrl+"/user") + response, err := s.httpGet(ctx, client, info.ApiUrl+"/user") if err != nil { return nil, fmt.Errorf("Error getting user info: %w", err) } @@ -232,14 +230,14 @@ func (s *SocialGitlab) extractFromAPI(ctx context.Context, client *http.Client, Groups: s.getGroups(ctx, client), } - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { var grafanaAdmin bool role, grafanaAdmin, err := s.extractRoleAndAdmin(response.Body, idData.Groups) if err != nil { return nil, err } - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { idData.IsGrafanaAdmin = &grafanaAdmin } @@ -256,6 +254,8 @@ func (s *SocialGitlab) extractFromAPI(ctx context.Context, client *http.Client, func (s *SocialGitlab) extractFromToken(ctx context.Context, client *http.Client, token *oauth2.Token) (*userData, error) { s.log.Debug("Extracting user info from OAuth token") + info := s.GetOAuthInfo() + idToken := token.Extra("id_token") if idToken == nil { s.log.Debug("No id_token found, defaulting to API access", "token", token) @@ -289,13 +289,13 @@ func (s *SocialGitlab) extractFromToken(ctx context.Context, client *http.Client data.Groups = userInfo.Groups } - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { role, grafanaAdmin, errRole := s.extractRoleAndAdmin(rawJSON, data.Groups) if errRole != nil { return nil, errRole } - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { data.IsGrafanaAdmin = &grafanaAdmin } diff --git a/pkg/login/social/connectors/gitlab_oauth_test.go b/pkg/login/social/connectors/gitlab_oauth_test.go index df0c3100cbb..1eeac7a7f86 100644 --- a/pkg/login/social/connectors/gitlab_oauth_test.go +++ b/pkg/login/social/connectors/gitlab_oauth_test.go @@ -517,3 +517,67 @@ func TestSocialGitlab_Validate(t *testing.T) { }) } } + +func TestSocialGitlab_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGitLabProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/google_oauth.go b/pkg/login/social/connectors/google_oauth.go index 36849984755..cf74c150d5c 100644 --- a/pkg/login/social/connectors/google_oauth.go +++ b/pkg/login/social/connectors/google_oauth.go @@ -71,11 +71,9 @@ func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSetti return nil } -func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { + info := s.GetOAuthInfo() + data, errToken := s.extractFromToken(ctx, client, token) if errToken != nil { return nil, errToken @@ -116,13 +114,13 @@ func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token Groups: groups, } - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { role, grafanaAdmin, errRole := s.extractRoleAndAdmin(data.rawJSON, groups) if errRole != nil { return nil, errRole } - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { userInfo.IsGrafanaAdmin = &grafanaAdmin } @@ -134,10 +132,6 @@ func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token return userInfo, nil } -func (s *SocialGoogle) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - type googleAPIData struct { ID string `json:"id"` Name string `json:"name"` @@ -146,9 +140,11 @@ type googleAPIData struct { } func (s *SocialGoogle) extractFromAPI(ctx context.Context, client *http.Client) (*googleUserData, error) { - if strings.HasPrefix(s.info.ApiUrl, legacyAPIURL) { + info := s.GetOAuthInfo() + + if strings.HasPrefix(info.ApiUrl, legacyAPIURL) { data := googleAPIData{} - response, err := s.httpGet(ctx, client, s.info.ApiUrl) + response, err := s.httpGet(ctx, client, info.ApiUrl) if err != nil { return nil, fmt.Errorf("error retrieving legacy user info: %s", err) } @@ -167,7 +163,7 @@ func (s *SocialGoogle) extractFromAPI(ctx context.Context, client *http.Client) } data := googleUserData{} - response, err := s.httpGet(ctx, client, s.info.ApiUrl) + response, err := s.httpGet(ctx, client, info.ApiUrl) if err != nil { return nil, fmt.Errorf("error getting user info: %s", err) } @@ -180,7 +176,9 @@ func (s *SocialGoogle) extractFromAPI(ctx context.Context, client *http.Client) } func (s *SocialGoogle) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - if s.info.UseRefreshToken { + info := s.GetOAuthInfo() + + if info.UseRefreshToken { opts = append(opts, oauth2.AccessTypeOffline, oauth2.ApprovalForce) } return s.SocialBase.AuthCodeURL(state, opts...) diff --git a/pkg/login/social/connectors/google_oauth_test.go b/pkg/login/social/connectors/google_oauth_test.go index c386311cb6c..5cd5627d589 100644 --- a/pkg/login/social/connectors/google_oauth_test.go +++ b/pkg/login/social/connectors/google_oauth_test.go @@ -722,3 +722,67 @@ func TestSocialGoogle_Validate(t *testing.T) { }) } } + +func TestSocialGoogle_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewGoogleProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/grafana_com_oauth.go b/pkg/login/social/connectors/grafana_com_oauth.go index 276548447ff..ced97b6459b 100644 --- a/pkg/login/social/connectors/grafana_com_oauth.go +++ b/pkg/login/social/connectors/grafana_com_oauth.go @@ -69,10 +69,6 @@ func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOS return nil } -func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { return true } @@ -104,6 +100,8 @@ func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ Orgs []OrgRecord `json:"orgs"` } + info := s.GetOAuthInfo() + response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user") if err != nil { @@ -117,7 +115,7 @@ func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ // on login we do not want to display the role from the external provider var role roletype.RoleType - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { role = org.RoleType(data.Role) } userInfo := &social.BasicUserInfo{ @@ -136,7 +134,3 @@ func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ return userInfo, nil } - -func (s *SocialGrafanaCom) GetOAuthInfo() *social.OAuthInfo { - return s.info -} diff --git a/pkg/login/social/connectors/grafana_com_oauth_test.go b/pkg/login/social/connectors/grafana_com_oauth_test.go index 9f827bcfc7c..340650f5079 100644 --- a/pkg/login/social/connectors/grafana_com_oauth_test.go +++ b/pkg/login/social/connectors/grafana_com_oauth_test.go @@ -190,3 +190,76 @@ func TestSocialGrafanaCom_Validate(t *testing.T) { }) } } + +func TestSocialGrafanaCom_Reload(t *testing.T) { + const GrafanaComURL = "http://localhost:3000" + + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "name": "a-new-name", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + Name: "a-new-name", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + // these are the overwrites from the constructor + AuthUrl: GrafanaComURL + "/oauth2/authorize", + TokenUrl: GrafanaComURL + "/api/oauth2/token", + AuthStyle: "inheader", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cfg := &setting.Cfg{ + GrafanaComURL: GrafanaComURL, + } + s := NewGrafanaComProvider(tc.info, cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/okta_oauth.go b/pkg/login/social/connectors/okta_oauth.go index 84a791db155..aab500562b4 100644 --- a/pkg/login/social/connectors/okta_oauth.go +++ b/pkg/login/social/connectors/okta_oauth.go @@ -77,10 +77,6 @@ func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSetting return nil } -func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { - return nil -} - func (claims *OktaClaims) extractEmail() string { if claims.Email == "" && claims.PreferredUsername != "" { return claims.PreferredUsername @@ -90,6 +86,8 @@ func (claims *OktaClaims) extractEmail() string { } func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { + info := s.GetOAuthInfo() + idToken := token.Extra("id_token") if idToken == nil { return nil, fmt.Errorf("no id_token found") @@ -123,18 +121,18 @@ func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *o var role roletype.RoleType var isGrafanaAdmin *bool - if !s.info.SkipOrgRoleSync { + if !info.SkipOrgRoleSync { var grafanaAdmin bool role, grafanaAdmin, err = s.extractRoleAndAdmin(data.rawJSON, groups) if err != nil { return nil, err } - if s.info.AllowAssignGrafanaAdmin { + if info.AllowAssignGrafanaAdmin { isGrafanaAdmin = &grafanaAdmin } } - if s.info.AllowAssignGrafanaAdmin && s.info.SkipOrgRoleSync { + if info.AllowAssignGrafanaAdmin && info.SkipOrgRoleSync { s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } @@ -149,14 +147,12 @@ func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *o }, nil } -func (s *SocialOkta) GetOAuthInfo() *social.OAuthInfo { - return s.info -} - func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error { - rawUserInfoResponse, err := s.httpGet(ctx, client, s.info.ApiUrl) + info := s.GetOAuthInfo() + + rawUserInfoResponse, err := s.httpGet(ctx, client, info.ApiUrl) if err != nil { - s.log.Debug("Error getting user info response", "url", s.info.ApiUrl, "error", err) + s.log.Debug("Error getting user info response", "url", info.ApiUrl, "error", err) return fmt.Errorf("error getting user info response: %w", err) } data.rawJSON = rawUserInfoResponse.Body @@ -182,11 +178,13 @@ func (s *SocialOkta) GetGroups(data *OktaUserInfoJson) []string { // TODO: remove this in a separate PR and use the isGroupMember from the social.go func (s *SocialOkta) IsGroupMember(groups []string) bool { - if len(s.info.AllowedGroups) == 0 { + info := s.GetOAuthInfo() + + if len(info.AllowedGroups) == 0 { return true } - for _, allowedGroup := range s.info.AllowedGroups { + for _, allowedGroup := range info.AllowedGroups { for _, group := range groups { if group == allowedGroup { return true diff --git a/pkg/login/social/connectors/okta_oauth_test.go b/pkg/login/social/connectors/okta_oauth_test.go index 2ebcdf50af5..f0d1ca82fe8 100644 --- a/pkg/login/social/connectors/okta_oauth_test.go +++ b/pkg/login/social/connectors/okta_oauth_test.go @@ -190,3 +190,67 @@ func TestSocialOkta_Validate(t *testing.T) { }) } } + +func TestSocialOkta_Reload(t *testing.T) { + testCases := []struct { + name string + info *social.OAuthInfo + settings ssoModels.SSOSettings + expectError bool + expectedInfo *social.OAuthInfo + }{ + { + name: "SSO provider successfully updated", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": "some-new-url", + }, + }, + expectError: false, + expectedInfo: &social.OAuthInfo{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + AuthUrl: "some-new-url", + }, + }, + { + name: "fails if settings contain invalid values", + info: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + settings: ssoModels.SSOSettings{ + Settings: map[string]any{ + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "auth_url": []string{"first", "second"}, + }, + }, + expectError: true, + expectedInfo: &social.OAuthInfo{ + ClientId: "client-id", + ClientSecret: "client-secret", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := NewOktaProvider(tc.info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) + + err := s.Reload(context.Background(), tc.settings) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.EqualValues(t, tc.expectedInfo, s.info) + }) + } +} diff --git a/pkg/login/social/connectors/social_base.go b/pkg/login/social/connectors/social_base.go index e5a9ce5c211..62c4e8247b2 100644 --- a/pkg/login/social/connectors/social_base.go +++ b/pkg/login/social/connectors/social_base.go @@ -3,12 +3,14 @@ package connectors import ( "bytes" "compress/zlib" + "context" "encoding/base64" "encoding/json" "fmt" "io" "regexp" "strings" + "sync" "golang.org/x/oauth2" "golang.org/x/text/cases" @@ -19,11 +21,13 @@ import ( "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" ) type SocialBase struct { *oauth2.Config info *social.OAuthInfo + infoMutex sync.RWMutex log log.Logger autoAssignOrgRole string features featuremgmt.FeatureToggles @@ -71,6 +75,27 @@ func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error { return nil } +func (s *SocialBase) GetOAuthInfo() *social.OAuthInfo { + s.infoMutex.RLock() + defer s.infoMutex.RUnlock() + + return s.info +} + +func (s *SocialBase) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + info, err := CreateOAuthInfoFromKeyValues(settings.Settings) + if err != nil { + return fmt.Errorf("SSO settings map cannot be converted to OAuthInfo: %v", err) + } + + s.infoMutex.Lock() + defer s.infoMutex.Unlock() + + s.info = info + + return nil +} + func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) { if s.info.RoleAttributePath == "" { if s.info.RoleAttributeStrict {