diff --git a/pkg/api/admin_users_test.go b/pkg/api/admin_users_test.go index c6504f24d2c..3ec408ad1de 100644 --- a/pkg/api/admin_users_test.go +++ b/pkg/api/admin_users_test.go @@ -15,7 +15,7 @@ import ( "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db/dbtest" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/login/socialtest" + "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" diff --git a/pkg/api/frontendsettings_test.go b/pkg/api/frontendsettings_test.go index 6fbbdc3cfa9..6d3726ada27 100644 --- a/pkg/api/frontendsettings_test.go +++ b/pkg/api/frontendsettings_test.go @@ -14,7 +14,7 @@ import ( "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/usagestats" - "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/socialimpl" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/config" "github.com/grafana/grafana/pkg/plugins/pluginscdn" @@ -25,6 +25,7 @@ import ( "github.com/grafana/grafana/pkg/services/pluginsintegration/pluginsettings" "github.com/grafana/grafana/pkg/services/pluginsintegration/pluginstore" "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest" "github.com/grafana/grafana/pkg/services/updatechecker" "github.com/grafana/grafana/pkg/setting" @@ -77,7 +78,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt. PluginSettings: cfg.PluginSettings, }), namespacer: request.GetNamespaceMapper(cfg), - SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()), + SocialService: socialimpl.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage(), &ssosettingstests.MockService{}), } m := web.New() diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index 2aa10271a73..08bf2833348 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -19,7 +19,7 @@ import ( "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db/dbtest" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/login/socialtest" + "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" diff --git a/pkg/login/social/azuread_jwks.go b/pkg/login/social/connectors/azuread_jwks.go similarity index 99% rename from pkg/login/social/azuread_jwks.go rename to pkg/login/social/connectors/azuread_jwks.go index 74113a84271..b6d051b08cd 100644 --- a/pkg/login/social/azuread_jwks.go +++ b/pkg/login/social/connectors/azuread_jwks.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "bytes" diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/connectors/azuread_oauth.go similarity index 85% rename from pkg/login/social/azuread_oauth.go rename to pkg/login/social/connectors/azuread_oauth.go index a1288926115..d2a5f9812dd 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/connectors/azuread_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "bytes" @@ -15,24 +15,25 @@ import ( "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) -const ( - AzureADProviderName = "azuread" - forceUseGraphAPIKey = "force_use_graph_api" // #nosec G101 not a hardcoded credential -) +const forceUseGraphAPIKey = "force_use_graph_api" // #nosec G101 not a hardcoded credential var ( ExtraAzureADSettingKeys = []string{forceUseGraphAPIKey, allowedOrganizationsKey} - errAzureADMissingGroups = &Error{"either the user does not have any group membership or the groups claim is missing from the token."} + errAzureADMissingGroups = &SocialError{"either the user does not have any group membership or the groups claim is missing from the token."} ) -var _ SocialConnector = (*SocialAzureAD)(nil) +var _ social.SocialConnector = (*SocialAzureAD)(nil) +var _ ssosettings.Reloadable = (*SocialAzureAD)(nil) type SocialAzureAD struct { *SocialBase @@ -72,15 +73,10 @@ type keySetJWKS struct { jose.JSONWebKeySet } -func NewAzureADProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (*SocialAzureAD, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - - config := createOAuthConfig(info, cfg, AzureADProviderName) +func NewAzureADProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) *SocialAzureAD { + config := createOAuthConfig(info, cfg, social.AzureADProviderName) provider := &SocialAzureAD{ - SocialBase: newSocialBase(AzureADProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.AzureADProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), cache: cache, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), forceUseGraphAPI: MustBool(info.Extra[forceUseGraphAPIKey], false), @@ -90,13 +86,17 @@ func NewAzureADProvider(settings map[string]any, cfg *setting.Cfg, features *fea } if info.UseRefreshToken && features.IsEnabledGlobally(featuremgmt.FlagAccessTokenExpirationCheck) { - appendUniqueScope(config, OfflineAccessScope) + appendUniqueScope(config, social.OfflineAccessScope) } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.AzureADProviderName, provider) + } + + return provider } -func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { return nil, ErrIDTokenNotFound @@ -155,7 +155,7 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } - return &BasicUserInfo{ + return &social.BasicUserInfo{ Id: claims.ID, Name: claims.Name, Email: email, @@ -166,7 +166,15 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token }, nil } -func (s *SocialAzureAD) GetOAuthInfo() *OAuthInfo { +func (s *SocialAzureAD) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialAzureAD) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialAzureAD) GetOAuthInfo() *social.OAuthInfo { return s.info } @@ -177,17 +185,17 @@ func (s *SocialAzureAD) validateClaims(ctx context.Context, client *http.Client, } if claims.OAuthVersion == "1.0" { - return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."} + return nil, &SocialError{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."} } s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID) if claims.Audience != s.ClientID { - return nil, &Error{"AzureAD OAuth: audience mismatch"} + return nil, &SocialError{"AzureAD OAuth: audience mismatch"} } s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations) if !s.isAllowedTenant(claims.TenantID) { - return nil, &Error{"AzureAD OAuth: tenant mismatch"} + return nil, &SocialError{"AzureAD OAuth: tenant mismatch"} } return claims, nil } @@ -226,7 +234,7 @@ func (s *SocialAzureAD) validateIDTokenSignature(ctx context.Context, client *ht s.log.Warn("AzureAD OAuth: signing key not found", "kid", keyID) - return nil, &Error{"AzureAD OAuth: signing key not found"} + return nil, &SocialError{"AzureAD OAuth: signing key not found"} } func (claims *azureClaims) extractEmail() string { @@ -248,11 +256,11 @@ func (s *SocialAzureAD) extractRoleAndAdmin(claims *azureClaims) (org.RoleType, return s.defaultRole(), false, nil } - roleOrder := []org.RoleType{RoleGrafanaAdmin, org.RoleAdmin, org.RoleEditor, + roleOrder := []org.RoleType{social.RoleGrafanaAdmin, org.RoleAdmin, org.RoleEditor, org.RoleViewer, org.RoleNone} for _, role := range roleOrder { if found := hasRole(claims.Roles, role); found { - if role == RoleGrafanaAdmin { + if role == social.RoleGrafanaAdmin { return org.RoleAdmin, true, nil } diff --git a/pkg/login/social/azuread_oauth_test.go b/pkg/login/social/connectors/azuread_oauth_test.go similarity index 79% rename from pkg/login/social/azuread_oauth_test.go rename to pkg/login/social/connectors/azuread_oauth_test.go index dada74594d2..543c993c6f5 100644 --- a/pkg/login/social/azuread_oauth_test.go +++ b/pkg/login/social/connectors/azuread_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -17,7 +17,9 @@ import ( "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) @@ -33,7 +35,7 @@ func falseBoolPtr() *bool { func TestSocialAzureAD_UserInfo(t *testing.T) { type fields struct { - providerCfg map[string]any + providerCfg *social.OAuthInfo cfg *setting.Cfg usGovURL bool } @@ -47,7 +49,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { claims *azureClaims args args settingAutoAssignOrgRole string - want *BasicUserInfo + want *social.BasicUserInfo wantErr bool }{ { @@ -60,15 +62,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", }, }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -80,9 +82,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "No email", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -102,9 +104,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { name: "No id token", claims: nil, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -123,16 +125,16 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", }, usGovURL: true, }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -151,15 +153,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", }, }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -171,9 +173,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Admin role", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -186,7 +188,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -198,9 +200,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Lowercase Admin role", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -213,7 +215,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -225,9 +227,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Only other roles", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -240,7 +242,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -260,15 +262,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", }, }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -287,15 +289,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", }, }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -307,9 +309,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Admin and Editor roles in claim", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", @@ -322,7 +324,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -334,10 +336,10 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Grafana Admin but setting is disabled", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": false, + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: false, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", @@ -351,7 +353,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -364,10 +366,10 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Editor roles in claim and GrafanaAdminAssignment enabled", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": true, + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: true, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -380,7 +382,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -393,10 +395,10 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Grafana Admin and Editor roles in claim", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": true, + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: true, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -409,7 +411,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -422,11 +424,11 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Error if user is not a member of allowed_groups", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": false, - "allowed_groups": "dead-beef", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: false, + AllowedGroups: []string{"dead-beef"}, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", @@ -446,11 +448,13 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Error if user is not a member of allowed_organizations", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": false, - "allowed_organizations": "uuid-1234", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: false, + Extra: map[string]string{ + "allowed_organizations": "uuid-1234", + }, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Editor", @@ -471,10 +475,12 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "No error if user is a member of allowed_organizations", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allowed_organizations": "uuid-1234,uuid-5678", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -489,7 +495,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -502,11 +508,11 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "No Error if user is a member of allowed_groups", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": "false", - "allowed_groups": "foo, bar", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: false, + AllowedGroups: []string{"foo", "bar"}, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -520,7 +526,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -532,11 +538,11 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Error if user does not have groups but allowed groups", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": "false", - "allowed_groups": "foo, bar", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: false, + AllowedGroups: []string{"foo", "bar"}, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "Viewer", @@ -556,9 +562,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch groups when ClaimsNames and ClaimsSources is set", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -574,7 +580,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ClaimSources: nil, // set by the test }, settingAutoAssignOrgRole: "", - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "test", Email: "test@test.com", @@ -587,10 +593,12 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch groups when forceUseGraphAPI is set", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "force_use_graph_api": "true", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + Extra: map[string]string{ + "force_use_graph_api": "true", + }, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -607,7 +615,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { Groups: []string{"foo", "bar"}, // must be ignored }, settingAutoAssignOrgRole: "", - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "test", Email: "test@test.com", @@ -620,10 +628,10 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch empty role when strict attribute role is true and no match", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "role_attribute_strict": "true", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + RoleAttributeStrict: true, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -643,10 +651,10 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch empty role when strict attribute role is true and no role claims returned", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "role_attribute_strict": "true", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + RoleAttributeStrict: true, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", @@ -699,8 +707,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := NewAzureADProvider(tt.fields.providerCfg, tt.fields.cfg, featuremgmt.WithFeatures(), cache) - require.NoError(t, err) + s := NewAzureADProvider(tt.fields.providerCfg, tt.fields.cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), cache) if tt.fields.usGovURL { s.SocialBase.Endpoint.AuthURL = usGovAuthURL @@ -767,7 +774,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { func TestSocialAzureAD_SkipOrgRole(t *testing.T) { type fields struct { SocialBase *SocialBase - providerCfg map[string]any + providerCfg *social.OAuthInfo cfg *setting.Cfg } @@ -776,21 +783,23 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { fields fields claims *azureClaims settingAutoAssignOrgRole string - want *BasicUserInfo + want *social.BasicUserInfo wantErr bool }{ { name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should get roles, skipOrgRoleSyncBase disabled", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": "true", - "skip_org_role_sync": "false", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: true, + // TODO: use this setting when SkipOrgRoleSync has moved to OAuthInfo + //SkipOrgRoleSync: false, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", OAuthSkipOrgRoleUpdateSync: false, + AzureADSkipOrgRoleSync: false, }, }, claims: &azureClaims{ @@ -800,7 +809,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -813,15 +822,17 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { { name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should not get roles", fields: fields{ - providerCfg: map[string]any{ - "name": "azuread", - "client_id": "client-id-example", - "allow_assign_grafana_admin": "true", - "skip_org_role_sync": "false", + providerCfg: &social.OAuthInfo{ + Name: "azuread", + ClientId: "client-id-example", + AllowAssignGrafanaAdmin: true, + // TODO: use this setting when SkipOrgRoleSync has moved to OAuthInfo + // SkipOrgRoleSync: false, }, cfg: &setting.Cfg{ AutoAssignOrgRole: "", OAuthSkipOrgRoleUpdateSync: false, + AzureADSkipOrgRoleSync: false, }, }, claims: &azureClaims{ @@ -831,7 +842,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { Name: "My Name", ID: "1234", }, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1234", Name: "My Name", Email: "me@example.com", @@ -875,8 +886,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := NewAzureADProvider(tt.fields.providerCfg, tt.fields.cfg, featuremgmt.WithFeatures(), cache) - require.NoError(t, err) + s := NewAzureADProvider(tt.fields.providerCfg, tt.fields.cfg, &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), cache) s.SocialBase.Endpoint.AuthURL = authURL @@ -943,13 +953,15 @@ func TestSocialAzureAD_InitializeExtraFields(t *testing.T) { } testCases := []struct { name string - settings map[string]any + settings *social.OAuthInfo want settingFields }{ { name: "forceUseGraphAPI is set to true", - settings: map[string]any{ - "force_use_graph_api": "true", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "force_use_graph_api": "true", + }, }, want: settingFields{ forceUseGraphAPI: true, @@ -958,8 +970,10 @@ func TestSocialAzureAD_InitializeExtraFields(t *testing.T) { }, { name: "allowedOrganizations is set", - settings: map[string]any{ - "allowed_organizations": "uuid-1234,uuid-5678", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, }, want: settingFields{ forceUseGraphAPI: false, @@ -970,8 +984,7 @@ func TestSocialAzureAD_InitializeExtraFields(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s, err := NewAzureADProvider(tc.settings, &setting.Cfg{}, featuremgmt.WithFeatures(), nil) - require.NoError(t, err) + s := NewAzureADProvider(tc.settings, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures(), nil) require.Equal(t, tc.want.forceUseGraphAPI, s.forceUseGraphAPI) require.Equal(t, tc.want.allowedOrganizations, s.allowedOrganizations) diff --git a/pkg/login/social/common.go b/pkg/login/social/connectors/common.go similarity index 87% rename from pkg/login/social/common.go rename to pkg/login/social/connectors/common.go index 3d40ee19ee9..a4452080c9e 100644 --- a/pkg/login/social/common.go +++ b/pkg/login/social/connectors/common.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -8,15 +8,17 @@ import ( "io" "net/http" "reflect" + "slices" "strconv" "strings" - "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" "github.com/jmespath/go-jmespath" "github.com/mitchellh/mapstructure" "golang.org/x/oauth2" - "gopkg.in/ini.v1" + + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" ) const ( @@ -27,7 +29,7 @@ const ( ) var ( - errMissingGroupMembership = &Error{"user not a member of one of the required groups"} + errMissingGroupMembership = &SocialError{"user not a member of one of the required groups"} ) type httpGetResponse struct { @@ -147,7 +149,7 @@ func (s *SocialBase) searchJSONForStringArrayAttr(attributePath string, data []b return result, nil } -func createOAuthConfig(info *OAuthInfo, cfg *setting.Cfg, defaultName string) *oauth2.Config { +func createOAuthConfig(info *social.OAuthInfo, cfg *setting.Cfg, defaultName string) *oauth2.Config { var authStyle oauth2.AuthStyle switch strings.ToLower(info.AuthStyle) { case "inparams": @@ -166,7 +168,7 @@ func createOAuthConfig(info *OAuthInfo, cfg *setting.Cfg, defaultName string) *o TokenURL: info.TokenUrl, AuthStyle: authStyle, }, - RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + SocialBaseUrl + defaultName, + RedirectURL: strings.TrimSuffix(cfg.AppURL, "/") + social.SocialBaseUrl + defaultName, Scopes: info.Scopes, } @@ -195,18 +197,9 @@ func MustBool(value any, defaultValue bool) bool { return result } -// convertIniSectionToMap converts key value pairs from an ini section to a map[string]any -func convertIniSectionToMap(sec *ini.Section) map[string]any { - mappedSettings := make(map[string]any) - for k, v := range sec.KeysHash() { - mappedSettings[k] = v - } - return mappedSettings -} - // CreateOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure // it puts all extra key values into OAuthInfo's Extra map -func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) { +func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*social.OAuthInfo, error) { emptyStrToSliceDecodeHook := func(from reflect.Type, to reflect.Type, data any) (any, error) { if from.Kind() == reflect.String && to.Kind() == reflect.Slice { strData, ok := data.(string) @@ -222,7 +215,7 @@ func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) return data, nil } - var oauthInfo OAuthInfo + var oauthInfo social.OAuthInfo decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: emptyStrToSliceDecodeHook, Result: &oauthInfo, @@ -244,3 +237,9 @@ func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) return &oauthInfo, err } + +func appendUniqueScope(config *oauth2.Config, scope string) { + if !slices.Contains(config.Scopes, social.OfflineAccessScope) { + config.Scopes = append(config.Scopes, social.OfflineAccessScope) + } +} diff --git a/pkg/login/social/errors.go b/pkg/login/social/connectors/errors.go similarity index 76% rename from pkg/login/social/errors.go rename to pkg/login/social/connectors/errors.go index 187c8ec8c80..0673d0547de 100644 --- a/pkg/login/social/errors.go +++ b/pkg/login/social/connectors/errors.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "errors" @@ -19,3 +19,12 @@ var ( errInvalidRole = errutil.BadRequest("oauth.invalid_role", errutil.WithPublicMessage("IdP did not return a valid role attribute, please contact your administrator")) ) + +// SocialError is a custom error type for social connectors to provide a public message when the connector expectaions are not met. +type SocialError struct { + s string +} + +func (e SocialError) Error() string { + return e.s +} diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/connectors/generic_oauth.go similarity index 92% rename from pkg/login/social/generic_oauth.go rename to pkg/login/social/connectors/generic_oauth.go index 29f4f28bb43..dfbf80fc4f3 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/connectors/generic_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "bytes" @@ -12,14 +12,15 @@ import ( "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) const ( - GenericOAuthProviderName = "generic_oauth" - nameAttributePathKey = "name_attribute_path" loginAttributePathKey = "login_attribute_path" idTokenAttributeNameKey = "id_token_attribute_name" // #nosec G101 not a hardcoded credential @@ -27,6 +28,9 @@ const ( var ExtraGenericOAuthSettingKeys = []string{nameAttributePathKey, loginAttributePathKey, idTokenAttributeNameKey, teamIdsKey, allowedOrganizationsKey} +var _ social.SocialConnector = (*SocialGenericOAuth)(nil) +var _ ssosettings.Reloadable = (*SocialGenericOAuth)(nil) + type SocialGenericOAuth struct { *SocialBase allowedOrganizations []string @@ -44,15 +48,10 @@ type SocialGenericOAuth struct { skipOrgRoleSync bool } -func NewGenericOAuthProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGenericOAuth, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - - config := createOAuthConfig(info, cfg, GenericOAuthProviderName) +func NewGenericOAuthProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialGenericOAuth { + config := createOAuthConfig(info, cfg, social.GenericOAuthProviderName) provider := &SocialGenericOAuth{ - SocialBase: newSocialBase(GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.GenericOAuthProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, teamsUrl: info.TeamsUrl, emailAttributeName: info.EmailAttributeName, @@ -70,7 +69,19 @@ func NewGenericOAuthProvider(settings map[string]any, cfg *setting.Cfg, features // skipOrgRoleSync: info.SkipOrgRoleSync } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.GenericOAuthProviderName, provider) + } + + return provider +} + +func (s *SocialGenericOAuth) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGenericOAuth) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil } // TODOD: remove this in the next PR and use the isGroupMember from social.go @@ -151,7 +162,7 @@ func (info *UserInfoJson) String() string { info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes) } -func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { s.log.Debug("Getting user info") toCheck := make([]*UserInfoJson, 0, 2) @@ -162,7 +173,7 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, toCheck = append(toCheck, apiData) } - userInfo := &BasicUserInfo{} + userInfo := &social.BasicUserInfo{} for _, data := range toCheck { s.log.Debug("Processing external user info", "source", data.source, "data", data) @@ -249,7 +260,7 @@ func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, return userInfo, nil } -func (s *SocialGenericOAuth) GetOAuthInfo() *OAuthInfo { +func (s *SocialGenericOAuth) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/generic_oauth_test.go b/pkg/login/social/connectors/generic_oauth_test.go similarity index 93% rename from pkg/login/social/generic_oauth_test.go rename to pkg/login/social/connectors/generic_oauth_test.go index 8157636ffa8..e90c60e4355 100644 --- a/pkg/login/social/generic_oauth_test.go +++ b/pkg/login/social/connectors/generic_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -12,16 +12,16 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" - + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) func TestSearchJSONForEmail(t *testing.T) { t.Run("Given a generic OAuth provider", func(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(social.NewOAuthInfo(), &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -105,8 +105,7 @@ func TestSearchJSONForEmail(t *testing.T) { func TestSearchJSONForGroups(t *testing.T) { t.Run("Given a generic OAuth provider", func(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(social.NewOAuthInfo(), &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -165,8 +164,7 @@ func TestSearchJSONForGroups(t *testing.T) { func TestSearchJSONForRole(t *testing.T) { t.Run("Given a generic OAuth provider", func(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(social.NewOAuthInfo(), &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -224,10 +222,11 @@ func TestSearchJSONForRole(t *testing.T) { } func TestUserInfoSearchesForEmailAndRole(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{ - "email_attribute_path": "email", - }, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(&social.OAuthInfo{ + EmailAttributePath: "email", + }, &setting.Cfg{}, + &ssosettingstests.MockService{}, + featuremgmt.WithFeatures()) tests := []struct { Name string @@ -492,10 +491,11 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) { func TestUserInfoSearchesForLogin(t *testing.T) { t.Run("Given a generic OAuth provider", func(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{ - "login_attribute_path": "login", - }, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(&social.OAuthInfo{ + Extra: map[string]string{ + "login_attribute_path": "login", + }, + }, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -585,10 +585,11 @@ func TestUserInfoSearchesForLogin(t *testing.T) { func TestUserInfoSearchesForName(t *testing.T) { t.Run("Given a generic OAuth provider", func(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{ - "name_attribute_path": "name", - }, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(&social.OAuthInfo{ + Extra: map[string]string{ + "name_attribute_path": "name", + }, + }, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -726,11 +727,10 @@ func TestUserInfoSearchesForGroup(t *testing.T) { require.NoError(t, err) })) - provider, err := NewGenericOAuthProvider(map[string]any{ - "groups_attribute_path": test.groupsAttributePath, - "api_url": ts.URL, - }, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(&social.OAuthInfo{ + GroupsAttributePath: test.groupsAttributePath, + ApiUrl: ts.URL, + }, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) token := &oauth2.Token{ AccessToken: "", @@ -748,10 +748,9 @@ func TestUserInfoSearchesForGroup(t *testing.T) { } func TestPayloadCompression(t *testing.T) { - provider, err := NewGenericOAuthProvider(map[string]any{ - "email_attribute_path": "email", - }, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGenericOAuthProvider(&social.OAuthInfo{ + EmailAttributePath: "email", + }, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) tests := []struct { Name string @@ -824,13 +823,15 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { } testCases := []struct { name string - settings map[string]any + settings *social.OAuthInfo want settingFields }{ { name: "nameAttributePath is set", - settings: map[string]any{ - "name_attribute_path": "name", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "name_attribute_path": "name", + }, }, want: settingFields{ nameAttributePath: "name", @@ -842,8 +843,10 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { }, { name: "loginAttributePath is set", - settings: map[string]any{ - "login_attribute_path": "login", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "login_attribute_path": "login", + }, }, want: settingFields{ nameAttributePath: "", @@ -855,8 +858,10 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { }, { name: "idTokenAttributeName is set", - settings: map[string]any{ - "id_token_attribute_name": "id_token", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "id_token_attribute_name": "id_token", + }, }, want: settingFields{ nameAttributePath: "", @@ -868,8 +873,10 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { }, { name: "teamIds is set", - settings: map[string]any{ - "team_ids": "[\"team1\", \"team2\"]", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "team_ids": "[\"team1\", \"team2\"]", + }, }, want: settingFields{ nameAttributePath: "", @@ -881,8 +888,10 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { }, { name: "allowedOrganizations is set", - settings: map[string]any{ - "allowed_organizations": "org1, org2", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "allowed_organizations": "org1, org2", + }, }, want: settingFields{ nameAttributePath: "", @@ -896,8 +905,7 @@ func TestSocialGenericOAuth_InitializeExtraFields(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s, err := NewGenericOAuthProvider(tc.settings, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + s := NewGenericOAuthProvider(tc.settings, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) require.Equal(t, tc.want.nameAttributePath, s.nameAttributePath) require.Equal(t, tc.want.loginAttributePath, s.loginAttributePath) diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/connectors/github_oauth.go similarity index 86% rename from pkg/login/social/github_oauth.go rename to pkg/login/social/connectors/github_oauth.go index 95848d21fc5..881952ce7a7 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/connectors/github_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -12,17 +12,21 @@ import ( "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util/errutil" ) -const GitHubProviderName = "github" - var ExtraGithubSettingKeys = []string{allowedOrganizationsKey, teamIdsKey} +var _ social.SocialConnector = (*SocialGithub)(nil) +var _ ssosettings.Reloadable = (*SocialGithub)(nil) + type SocialGithub struct { *SocialBase allowedOrganizations []string @@ -51,17 +55,12 @@ var ( "User is not a member of one of the required organizations. Please contact identity provider administrator.")) ) -func NewGitHubProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGithub, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - +func NewGitHubProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialGithub { teamIds := mustInts(util.SplitString(info.Extra[teamIdsKey])) - config := createOAuthConfig(info, cfg, GitHubProviderName) + config := createOAuthConfig(info, cfg, social.GitHubProviderName) provider := &SocialGithub{ - SocialBase: newSocialBase(GitHubProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.GitHubProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, teamIds: teamIds, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), @@ -70,7 +69,19 @@ func NewGitHubProvider(settings map[string]any, cfg *setting.Cfg, features *feat // skipOrgRoleSync: info.SkipOrgRoleSync } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.GitHubProviderName, provider) + } + + return provider +} + +func (s *SocialGithub) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGithub) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil } func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool { @@ -220,7 +231,7 @@ func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Clie return logins, nil } -func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { var data struct { Id int `json:"id"` Login string `json:"login"` @@ -264,7 +275,7 @@ func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } - userInfo := &BasicUserInfo{ + userInfo := &social.BasicUserInfo{ Name: data.Login, Login: data.Login, Id: fmt.Sprintf("%d", data.Id), @@ -306,7 +317,7 @@ func (t *GithubTeam) GetShorthand() (string, error) { return fmt.Sprintf("@%s/%s", t.Organization.Login, t.Slug), nil } -func (s *SocialGithub) GetOAuthInfo() *OAuthInfo { +func (s *SocialGithub) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/github_oauth_test.go b/pkg/login/social/connectors/github_oauth_test.go similarity index 86% rename from pkg/login/social/github_oauth_test.go rename to pkg/login/social/connectors/github_oauth_test.go index 1326c6cde12..ccfff095e56 100644 --- a/pkg/login/social/github_oauth_test.go +++ b/pkg/login/social/connectors/github_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -11,7 +11,9 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) @@ -124,7 +126,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { settingSkipOrgRoleSync bool roleAttributePath string autoAssignOrgRole string - want *BasicUserInfo + want *social.BasicUserInfo wantErr bool }{ { @@ -133,7 +135,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { userTeamsRawJSON: testGHUserTeamsJSON, autoAssignOrgRole: "", roleAttributePath: "", - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -148,7 +150,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { userRawJSON: testGHUserJSON, autoAssignOrgRole: "Editor", userTeamsRawJSON: testGHUserTeamsJSON, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -163,7 +165,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { userRawJSON: testGHUserJSON, autoAssignOrgRole: "Editor", userTeamsRawJSON: testGHUserTeamsJSON, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -178,7 +180,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { settingSkipOrgRoleSync: true, userRawJSON: testGHUserJSON, userTeamsRawJSON: testGHUserTeamsJSON, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -194,7 +196,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { settingAllowGrafanaAdmin: true, userRawJSON: testGHUserJSON, userTeamsRawJSON: testGHUserTeamsJSON, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -210,7 +212,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { userRawJSON: testGHUserJSON, autoAssignOrgRole: "Editor", userTeamsRawJSON: testGHUserTeamsJSON, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "monalisa octocat", Email: "octocat@github.com", @@ -239,16 +241,19 @@ func TestSocialGitHub_UserInfo(t *testing.T) { })) defer server.Close() - s, err := NewGitHubProvider(map[string]any{ - "allowed_organizations": "", - "api_url": server.URL + "/user", - "team_ids": "", - "role_attribute_path": tt.roleAttributePath, - }, &setting.Cfg{ - AutoAssignOrgRole: tt.autoAssignOrgRole, - GitHubSkipOrgRoleSync: tt.settingSkipOrgRoleSync, - }, featuremgmt.WithFeatures()) - require.NoError(t, err) + s := NewGitHubProvider( + &social.OAuthInfo{ + ApiUrl: server.URL + "/user", + RoleAttributePath: tt.roleAttributePath, + Extra: map[string]string{ + "allowed_organizations": "", + "team_ids": "", + }, + }, &setting.Cfg{ + AutoAssignOrgRole: tt.autoAssignOrgRole, + GitHubSkipOrgRoleSync: tt.settingSkipOrgRoleSync, + }, &ssosettingstests.MockService{}, + featuremgmt.WithFeatures()) token := &oauth2.Token{ AccessToken: "fake_token", @@ -273,13 +278,15 @@ func TestSocialGitHub_InitializeExtraFields(t *testing.T) { } testCases := []struct { name string - settings map[string]any + settings *social.OAuthInfo want settingFields }{ { name: "teamIds is set", - settings: map[string]any{ - "team_ids": "1234,5678", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "team_ids": "1234,5678", + }, }, want: settingFields{ teamIds: []int{1234, 5678}, @@ -288,8 +295,10 @@ func TestSocialGitHub_InitializeExtraFields(t *testing.T) { }, { name: "allowedOrganizations is set", - settings: map[string]any{ - "allowed_organizations": "uuid-1234,uuid-5678", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, }, want: settingFields{ teamIds: []int{}, @@ -298,9 +307,11 @@ func TestSocialGitHub_InitializeExtraFields(t *testing.T) { }, { name: "teamIds and allowedOrganizations are empty", - settings: map[string]any{ - "team_ids": "", - "allowed_organizations": "", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "team_ids": "", + "allowed_organizations": "", + }, }, want: settingFields{ teamIds: []int{}, @@ -309,8 +320,10 @@ func TestSocialGitHub_InitializeExtraFields(t *testing.T) { }, { name: "should not error when teamIds are not integers", - settings: map[string]any{ - "team_ids": "abc1234,5678", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "team_ids": "abc1234,5678", + }, }, want: settingFields{ teamIds: []int{}, @@ -321,8 +334,7 @@ func TestSocialGitHub_InitializeExtraFields(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s, err := NewGitHubProvider(tc.settings, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + s := NewGitHubProvider(tc.settings, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) require.Equal(t, tc.want.teamIds, s.teamIds) require.Equal(t, tc.want.allowedOrganizations, s.allowedOrganizations) diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/connectors/gitlab_oauth.go similarity index 86% rename from pkg/login/social/gitlab_oauth.go rename to pkg/login/social/connectors/gitlab_oauth.go index bee41ccc693..9dcbe4830fd 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/connectors/gitlab_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -11,17 +11,22 @@ import ( "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" ) const ( - groupPerPage = 50 - accessLevelGuest = "10" - GitlabProviderName = "gitlab" + groupPerPage = 50 + accessLevelGuest = "10" ) +var _ social.SocialConnector = (*SocialGitlab)(nil) +var _ ssosettings.Reloadable = (*SocialGitlab)(nil) + type SocialGitlab struct { *SocialBase apiUrl string @@ -49,22 +54,29 @@ type userData struct { IsGrafanaAdmin *bool `json:"-"` } -func NewGitLabProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGitlab, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - - config := createOAuthConfig(info, cfg, GitlabProviderName) +func NewGitLabProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialGitlab { + config := createOAuthConfig(info, cfg, social.GitlabProviderName) provider := &SocialGitlab{ - SocialBase: newSocialBase(GitlabProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.GitlabProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, skipOrgRoleSync: cfg.GitLabSkipOrgRoleSync, // FIXME: Move skipOrgRoleSync to OAuthInfo // skipOrgRoleSync: info.SkipOrgRoleSync } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.GitlabProviderName, provider) + } + + return provider +} + +func (s *SocialGitlab) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGitlab) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil } func (s *SocialGitlab) getGroups(ctx context.Context, client *http.Client) []string { @@ -146,7 +158,7 @@ func (s *SocialGitlab) getGroupsPage(ctx context.Context, client *http.Client, n return fullPaths, next } -func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { data, err := s.extractFromToken(ctx, client, token) if err != nil { return nil, err @@ -161,7 +173,7 @@ func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token } } - userInfo := &BasicUserInfo{ + userInfo := &social.BasicUserInfo{ Id: data.ID, Name: data.Name, Login: data.Login, @@ -182,7 +194,7 @@ func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, token return userInfo, nil } -func (s *SocialGitlab) GetOAuthInfo() *OAuthInfo { +func (s *SocialGitlab) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/gitlab_oauth_test.go b/pkg/login/social/connectors/gitlab_oauth_test.go similarity index 94% rename from pkg/login/social/gitlab_oauth_test.go rename to pkg/login/social/connectors/gitlab_oauth_test.go index 6b20fb7a14d..b37040320f7 100644 --- a/pkg/login/social/gitlab_oauth_test.go +++ b/pkg/login/social/connectors/gitlab_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -15,8 +15,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) @@ -39,8 +41,7 @@ const ( func TestSocialGitlab_UserInfo(t *testing.T) { var nilPointer *bool - provider, err := NewGitLabProvider(map[string]any{"skip_org_role_sync": false}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGitLabProvider(&social.OAuthInfo{SkipOrgRoleSync: false}, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) type conf struct { AllowAssignGrafanaAdmin bool @@ -346,21 +347,23 @@ func TestSocialGitlab_extractFromToken(t *testing.T) { // Create a test client with a dummy token client := oauth2.NewClient(context.Background(), &tokenSource{accessToken: "dummy_access_token"}) - s, err := NewGitLabProvider(map[string]any{ - "allowed_domains": []string{}, - "allow_sign_up": false, - "role_attribute_path": "", - "role_attribute_strict": false, - "skip_org_role_sync": false, - "auth_url": tc.config.Endpoint.AuthURL, - "token_url": tc.config.Endpoint.TokenURL, - }, + s := NewGitLabProvider( + &social.OAuthInfo{ + AllowedDomains: []string{}, + AllowSignup: false, + RoleAttributePath: "", + RoleAttributeStrict: false, + // TODO: use this setting when SkipOrgRoleSync has moved to OAuthInfo + //SkipOrgRoleSync: false, + AuthUrl: tc.config.Endpoint.AuthURL, + TokenUrl: tc.config.Endpoint.TokenURL, + }, &setting.Cfg{ AutoAssignOrgRole: "", OAuthSkipOrgRoleUpdateSync: false, - }, featuremgmt.WithFeatures()) - - require.NoError(t, err) + GitLabSkipOrgRoleSync: false, + }, &ssosettingstests.MockService{}, + featuremgmt.WithFeatures()) // Test case: successful extraction token := &oauth2.Token{} @@ -450,8 +453,7 @@ func TestSocialGitlab_GetGroupsNextPage(t *testing.T) { defer mockServer.Close() // Create a SocialGitlab instance with the mock server URL - s, err := NewGitLabProvider(map[string]any{"api_url": mockServer.URL}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + s := NewGitLabProvider(&social.OAuthInfo{ApiUrl: mockServer.URL}, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) // Call getGroups and verify that it returns all groups expectedGroups := []string{"admins", "editors", "viewers", "serveradmins"} diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/connectors/google_oauth.go similarity index 84% rename from pkg/login/social/google_oauth.go rename to pkg/login/social/connectors/google_oauth.go index 4164c267a3f..69d14f27be6 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/connectors/google_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -10,7 +10,10 @@ import ( "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" ) @@ -18,9 +21,11 @@ const ( legacyAPIURL = "https://www.googleapis.com/oauth2/v1/userinfo" googleIAMGroupsEndpoint = "https://content-cloudidentity.googleapis.com/v1/groups/-/memberships:searchDirectGroups" googleIAMScope = "https://www.googleapis.com/auth/cloud-identity.groups.readonly" - GoogleProviderName = "google" ) +var _ social.SocialConnector = (*SocialGoogle)(nil) +var _ ssosettings.Reloadable = (*SocialGoogle)(nil) + type SocialGoogle struct { *SocialBase hostedDomain string @@ -36,15 +41,10 @@ type googleUserData struct { rawJSON []byte `json:"-"` } -func NewGoogleProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGoogle, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - - config := createOAuthConfig(info, cfg, GoogleProviderName) +func NewGoogleProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialGoogle { + config := createOAuthConfig(info, cfg, social.GoogleProviderName) provider := &SocialGoogle{ - SocialBase: newSocialBase(GoogleProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.GoogleProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), hostedDomain: info.HostedDomain, apiUrl: info.ApiUrl, skipOrgRoleSync: cfg.GoogleSkipOrgRoleSync, @@ -56,10 +56,22 @@ func NewGoogleProvider(settings map[string]any, cfg *setting.Cfg, features *feat provider.log.Warn("Using legacy Google API URL, please update your configuration") } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.GoogleProviderName, provider) + } + + return provider } -func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGoogle) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGoogle) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { data, errToken := s.extractFromToken(ctx, client, token) if errToken != nil { return nil, errToken @@ -90,7 +102,7 @@ func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token return nil, errMissingGroupMembership } - userInfo := &BasicUserInfo{ + userInfo := &social.BasicUserInfo{ Id: data.ID, Name: data.Name, Email: data.Email, @@ -118,7 +130,7 @@ func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token return userInfo, nil } -func (s *SocialGoogle) GetOAuthInfo() *OAuthInfo { +func (s *SocialGoogle) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/google_oauth_test.go b/pkg/login/social/connectors/google_oauth_test.go similarity index 92% rename from pkg/login/social/google_oauth_test.go rename to pkg/login/social/connectors/google_oauth_test.go index b5785bee133..cde6ecc8666 100644 --- a/pkg/login/social/google_oauth_test.go +++ b/pkg/login/social/connectors/google_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -14,8 +14,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) @@ -181,22 +183,23 @@ func TestSocialGoogle_retrieveGroups(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := NewGoogleProvider(map[string]any{ - "api_url": "", - "scopes": tt.fields.Scopes, - "hosted_domain": "", - "allowed_domains": []string{}, - "allow_sign_up": false, - "role_attribute_path": "", - "role_attribute_strict": false, - "allow_assign_grafana_admin": false, - }, + s := NewGoogleProvider( + &social.OAuthInfo{ + ApiUrl: "", + Scopes: tt.fields.Scopes, + HostedDomain: "", + AllowedDomains: []string{}, + AllowSignup: false, + RoleAttributePath: "", + RoleAttributeStrict: false, + AllowAssignGrafanaAdmin: false, + }, &setting.Cfg{ AutoAssignOrgRole: "", GoogleSkipOrgRoleSync: false, }, + &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) - require.NoError(t, err) got, err := s.retrieveGroups(context.Background(), tt.args.client, tt.args.userData) if (err != nil) != tt.wantErr { @@ -259,7 +262,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { name string fields fields args args - wantData *BasicUserInfo + wantData *social.BasicUserInfo wantErr bool wantErrMsg string }{ @@ -272,7 +275,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { args: args{ token: tokenWithID, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -309,7 +312,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { }, }, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -341,7 +344,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { }, }, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "99999999999999", Login: "test@example.com", Email: "test@example.com", @@ -459,7 +462,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { }, }, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "92222222222222222", Name: "Test User", Email: "test@example.com", @@ -521,7 +524,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { }, }, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -542,7 +545,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { args: args{ token: tokenWithID, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -562,7 +565,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { args: args{ token: tokenWithID, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -582,7 +585,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { args: args{ token: tokenWithID, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -621,7 +624,7 @@ func TestSocialGoogle_UserInfo(t *testing.T) { }, }, }, - wantData: &BasicUserInfo{ + wantData: &social.BasicUserInfo{ Id: "88888888888888", Login: "test@example.com", Email: "test@example.com", @@ -635,20 +638,23 @@ func TestSocialGoogle_UserInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := NewGoogleProvider(map[string]any{ - "api_url": tt.fields.apiURL, - "scopes": tt.fields.Scopes, - "allowed_groups": tt.fields.allowedGroups, - "allow_sign_up": false, - "role_attribute_path": tt.fields.roleAttributePath, - "role_attribute_strict": tt.fields.roleAttributeStrict, - "allow_assign_grafana_admin": tt.fields.allowAssignGrafanaAdmin, - }, + s := NewGoogleProvider( + &social.OAuthInfo{ + ApiUrl: tt.fields.apiURL, + Scopes: tt.fields.Scopes, + AllowedGroups: tt.fields.allowedGroups, + AllowSignup: false, + RoleAttributePath: tt.fields.roleAttributePath, + RoleAttributeStrict: tt.fields.roleAttributeStrict, + AllowAssignGrafanaAdmin: tt.fields.allowAssignGrafanaAdmin, + // TODO: use this setting when SkipOrgRoleSync has moved to OAuthInfo + // SkipOrgRoleSync: tt.fields.skipOrgRoleSync, + }, &setting.Cfg{ GoogleSkipOrgRoleSync: tt.fields.skipOrgRoleSync, }, + &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) - require.NoError(t, err) gotData, err := s.UserInfo(context.Background(), tt.args.client, tt.args.token) if tt.wantErr { diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/connectors/grafana_com_oauth.go similarity index 68% rename from pkg/login/social/grafana_com_oauth.go rename to pkg/login/social/connectors/grafana_com_oauth.go index 6dce7b1d85b..3d6071e78d1 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/connectors/grafana_com_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -8,21 +8,21 @@ import ( "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) -const ( - GrafanaComProviderName = "grafana_com" - // legacy/old settings for the provider - GrafanaNetProviderName = "grafananet" -) - var ExtraGrafanaComSettingKeys = []string{allowedOrganizationsKey} +var _ social.SocialConnector = (*SocialGrafanaCom)(nil) +var _ ssosettings.Reloadable = (*SocialGrafanaCom)(nil) + type SocialGrafanaCom struct { *SocialBase url string @@ -34,20 +34,15 @@ type OrgRecord struct { Login string `json:"login"` } -func NewGrafanaComProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGrafanaCom, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - +func NewGrafanaComProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialGrafanaCom { // Override necessary settings info.AuthUrl = cfg.GrafanaComURL + "/oauth2/authorize" info.TokenUrl = cfg.GrafanaComURL + "/api/oauth2/token" info.AuthStyle = "inheader" - config := createOAuthConfig(info, cfg, GrafanaComProviderName) + config := createOAuthConfig(info, cfg, social.GrafanaComProviderName) provider := &SocialGrafanaCom{ - SocialBase: newSocialBase(GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.GrafanaComProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), url: cfg.GrafanaComURL, allowedOrganizations: util.SplitString(info.Extra[allowedOrganizationsKey]), skipOrgRoleSync: cfg.GrafanaComSkipOrgRoleSync, @@ -55,7 +50,19 @@ func NewGrafanaComProvider(settings map[string]any, cfg *setting.Cfg, features * // skipOrgRoleSync: info.SkipOrgRoleSync } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.GrafanaComProviderName, provider) + } + + return provider +} + +func (s *SocialGrafanaCom) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialGrafanaCom) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil } func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { @@ -79,7 +86,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool } // UserInfo is used for login credentials for the user -func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*social.BasicUserInfo, error) { var data struct { Id int `json:"id"` Name string `json:"name"` @@ -105,7 +112,7 @@ func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ if !s.skipOrgRoleSync { role = org.RoleType(data.Role) } - userInfo := &BasicUserInfo{ + userInfo := &social.BasicUserInfo{ Id: fmt.Sprintf("%d", data.Id), Name: data.Name, Login: data.Login, @@ -122,6 +129,6 @@ func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ return userInfo, nil } -func (s *SocialGrafanaCom) GetOAuthInfo() *OAuthInfo { +func (s *SocialGrafanaCom) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/grafana_com_oauth_test.go b/pkg/login/social/connectors/grafana_com_oauth_test.go similarity index 80% rename from pkg/login/social/grafana_com_oauth_test.go rename to pkg/login/social/connectors/grafana_com_oauth_test.go index a937cb51a9b..35629547f00 100644 --- a/pkg/login/social/grafana_com_oauth_test.go +++ b/pkg/login/social/connectors/grafana_com_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -6,9 +6,12 @@ import ( "net/http/httptest" "testing" - "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" + "github.com/grafana/grafana/pkg/setting" ) const ( @@ -25,8 +28,7 @@ const ( ) func TestSocialGrafanaCom_UserInfo(t *testing.T) { - provider, err := NewGrafanaComProvider(map[string]any{}, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + provider := NewGrafanaComProvider(social.NewOAuthInfo(), &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) type conf struct { skipOrgRoleSync bool @@ -36,14 +38,14 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) { Name string Cfg conf userInfoResp string - want *BasicUserInfo + want *social.BasicUserInfo ExpectedError error }{ { Name: "should return empty role as userInfo when Skip Org Role Sync Enabled", userInfoResp: userResponse, Cfg: conf{skipOrgRoleSync: true}, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "Eric Leijonmarck", Email: "octocat@github.com", @@ -55,7 +57,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) { Name: "should return role as userInfo when Skip Org Role Sync Enabled", userInfoResp: userResponse, Cfg: conf{skipOrgRoleSync: false}, - want: &BasicUserInfo{ + want: &social.BasicUserInfo{ Id: "1", Name: "Eric Leijonmarck", Email: "octocat@github.com", @@ -99,20 +101,22 @@ func TestSocialGrafanaCom_InitializeExtraFields(t *testing.T) { } testCases := []struct { name string - settings map[string]any + settings *social.OAuthInfo want settingFields }{ { name: "allowedOrganizations is not set", - settings: map[string]any{}, + settings: social.NewOAuthInfo(), want: settingFields{ allowedOrganizations: []string{}, }, }, { name: "allowedOrganizations is set", - settings: map[string]any{ - "allowed_organizations": "uuid-1234,uuid-5678", + settings: &social.OAuthInfo{ + Extra: map[string]string{ + "allowed_organizations": "uuid-1234,uuid-5678", + }, }, want: settingFields{ allowedOrganizations: []string{"uuid-1234", "uuid-5678"}, @@ -122,8 +126,7 @@ func TestSocialGrafanaCom_InitializeExtraFields(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s, err := NewGrafanaComProvider(tc.settings, &setting.Cfg{}, featuremgmt.WithFeatures()) - require.NoError(t, err) + s := NewGrafanaComProvider(tc.settings, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) require.Equal(t, tc.want.allowedOrganizations, s.allowedOrganizations) }) diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/connectors/okta_oauth.go similarity index 78% rename from pkg/login/social/okta_oauth.go rename to pkg/login/social/connectors/okta_oauth.go index fe2d41c0a4d..1ba9b1f892e 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/connectors/okta_oauth.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -10,12 +10,16 @@ import ( "github.com/go-jose/go-jose/v3/jwt" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/setting" ) -const OktaProviderName = "okta" +var _ social.SocialConnector = (*SocialOkta)(nil) +var _ ssosettings.Reloadable = (*SocialOkta)(nil) type SocialOkta struct { *SocialBase @@ -43,15 +47,10 @@ type OktaClaims struct { Name string `json:"name"` } -func NewOktaProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialOkta, error) { - info, err := CreateOAuthInfoFromKeyValues(settings) - if err != nil { - return nil, err - } - - config := createOAuthConfig(info, cfg, OktaProviderName) +func NewOktaProvider(info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager) *SocialOkta { + config := createOAuthConfig(info, cfg, social.OktaProviderName) provider := &SocialOkta{ - SocialBase: newSocialBase(OktaProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), + SocialBase: newSocialBase(social.OktaProviderName, config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, allowedGroups: info.AllowedGroups, // FIXME: Move skipOrgRoleSync to OAuthInfo @@ -60,10 +59,22 @@ func NewOktaProvider(settings map[string]any, cfg *setting.Cfg, features *featur } if info.UseRefreshToken && features.IsEnabledGlobally(featuremgmt.FlagAccessTokenExpirationCheck) { - appendUniqueScope(config, OfflineAccessScope) + appendUniqueScope(config, social.OfflineAccessScope) } - return provider, nil + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + ssoSettings.RegisterReloadable(social.OktaProviderName, provider) + } + + return provider +} + +func (s *SocialOkta) Validate(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil +} + +func (s *SocialOkta) Reload(ctx context.Context, settings ssoModels.SSOSettings) error { + return nil } func (claims *OktaClaims) extractEmail() string { @@ -74,7 +85,7 @@ func (claims *OktaClaims) extractEmail() string { return claims.Email } -func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { return nil, fmt.Errorf("no id_token found") @@ -123,7 +134,7 @@ func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *o s.log.Debug("AllowAssignGrafanaAdmin and skipOrgRoleSync are both set, Grafana Admin role will not be synced, consider setting one or the other") } - return &BasicUserInfo{ + return &social.BasicUserInfo{ Id: claims.ID, Name: claims.Name, Email: email, @@ -134,7 +145,7 @@ func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *o }, nil } -func (s *SocialOkta) GetOAuthInfo() *OAuthInfo { +func (s *SocialOkta) GetOAuthInfo() *social.OAuthInfo { return s.info } diff --git a/pkg/login/social/okta_oauth_test.go b/pkg/login/social/connectors/okta_oauth_test.go similarity index 87% rename from pkg/login/social/okta_oauth_test.go rename to pkg/login/social/connectors/okta_oauth_test.go index 1b0cdff3879..db3ebd199fe 100644 --- a/pkg/login/social/okta_oauth_test.go +++ b/pkg/login/social/connectors/okta_oauth_test.go @@ -1,4 +1,4 @@ -package social +package connectors import ( "context" @@ -12,8 +12,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) @@ -67,7 +69,7 @@ func TestSocialOkta_UserInfo(t *testing.T) { }, { name: "Should give grafanaAdmin role for specific GrafanaAdmin in the role assignement", - userRawJSON: fmt.Sprintf(`{ "email": "okta-octopus@grafana.com", "role": "%s" }`, RoleGrafanaAdmin), + userRawJSON: fmt.Sprintf(`{ "email": "okta-octopus@grafana.com", "role": "%s" }`, social.RoleGrafanaAdmin), RoleAttributePath: "role", allowAssignGrafanaAdmin: true, OAuth2Extra: map[string]any{ @@ -97,20 +99,21 @@ func TestSocialOkta_UserInfo(t *testing.T) { })) defer server.Close() - provider, err := NewOktaProvider( - map[string]any{ - "api_url": server.URL + "/user", - "role_attribute_path": tt.RoleAttributePath, - "allow_assign_grafana_admin": tt.allowAssignGrafanaAdmin, - "skip_org_role_sync": tt.settingSkipOrgRoleSync, + provider := NewOktaProvider( + &social.OAuthInfo{ + ApiUrl: server.URL + "/user", + RoleAttributePath: tt.RoleAttributePath, + AllowAssignGrafanaAdmin: tt.allowAssignGrafanaAdmin, + // TODO: use this setting when SkipOrgRoleSync has moved to OAuthInfo + // SkipOrgRoleSync: tt.settingSkipOrgRoleSync, }, &setting.Cfg{ OktaSkipOrgRoleSync: tt.settingSkipOrgRoleSync, AutoAssignOrgRole: tt.autoAssignOrgRole, OAuthSkipOrgRoleUpdateSync: false, }, + &ssosettingstests.MockService{}, featuremgmt.WithFeatures()) - require.NoError(t, err) // create a oauth2 token with a id_token staticToken := oauth2.Token{ diff --git a/pkg/login/social/connectors/social_base.go b/pkg/login/social/connectors/social_base.go new file mode 100644 index 00000000000..0ace2fb9baf --- /dev/null +++ b/pkg/login/social/connectors/social_base.go @@ -0,0 +1,229 @@ +package connectors + +import ( + "bytes" + "compress/zlib" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "regexp" + "strings" + + "golang.org/x/oauth2" + "golang.org/x/text/cases" + "golang.org/x/text/language" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/org" +) + +type SocialBase struct { + *oauth2.Config + info *social.OAuthInfo + log log.Logger + allowSignup bool + allowAssignGrafanaAdmin bool + allowedDomains []string + allowedGroups []string + + roleAttributePath string + roleAttributeStrict bool + autoAssignOrgRole string + skipOrgRoleSync bool + features featuremgmt.FeatureManager + useRefreshToken bool +} + +func newSocialBase(name string, + config *oauth2.Config, + info *social.OAuthInfo, + autoAssignOrgRole string, + skipOrgRoleSync bool, + features featuremgmt.FeatureManager, +) *SocialBase { + logger := log.New("oauth." + name) + + return &SocialBase{ + Config: config, + info: info, + log: logger, + allowSignup: info.AllowSignup, + allowAssignGrafanaAdmin: info.AllowAssignGrafanaAdmin, + allowedDomains: info.AllowedDomains, + allowedGroups: info.AllowedGroups, + roleAttributePath: info.RoleAttributePath, + roleAttributeStrict: info.RoleAttributeStrict, + autoAssignOrgRole: autoAssignOrgRole, + skipOrgRoleSync: skipOrgRoleSync, + features: features, + useRefreshToken: info.UseRefreshToken, + } +} + +type groupStruct struct { + Groups []string `json:"groups"` +} + +func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error { + bf.WriteString("## Client configuration\n\n") + bf.WriteString("```ini\n") + bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.allowAssignGrafanaAdmin)) + bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.allowSignup)) + bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.allowedDomains)) + bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole)) + bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.roleAttributePath)) + bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.roleAttributeStrict)) + bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.skipOrgRoleSync)) + bf.WriteString(fmt.Sprintf("client_id = %v\n", s.Config.ClientID)) + bf.WriteString(fmt.Sprintf("client_secret = %v ; issue if empty\n", strings.Repeat("*", len(s.Config.ClientSecret)))) + bf.WriteString(fmt.Sprintf("auth_url = %v\n", s.Config.Endpoint.AuthURL)) + bf.WriteString(fmt.Sprintf("token_url = %v\n", s.Config.Endpoint.TokenURL)) + bf.WriteString(fmt.Sprintf("auth_style = %v\n", s.Config.Endpoint.AuthStyle)) + bf.WriteString(fmt.Sprintf("redirect_url = %v\n", s.Config.RedirectURL)) + bf.WriteString(fmt.Sprintf("scopes = %v\n", s.Config.Scopes)) + bf.WriteString("```\n\n") + return nil +} + +func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) { + if s.roleAttributePath == "" { + if s.roleAttributeStrict { + return "", false, errRoleAttributePathNotSet.Errorf("role_attribute_path not set and role_attribute_strict is set") + } + return "", false, nil + } + + if role, gAdmin := s.searchRole(rawJSON, groups); role.IsValid() { + return role, gAdmin, nil + } else if role != "" { + return "", false, errInvalidRole.Errorf("invalid role: %s", role) + } + + if s.roleAttributeStrict { + return "", false, errRoleAttributeStrictViolation.Errorf("idP did not return a role attribute, but role_attribute_strict is set") + } + + return "", false, nil +} + +func (s *SocialBase) extractRoleAndAdmin(rawJSON []byte, groups []string) (org.RoleType, bool, error) { + role, gAdmin, err := s.extractRoleAndAdminOptional(rawJSON, groups) + if role == "" { + role = s.defaultRole() + } + + return role, gAdmin, err +} + +func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType, bool) { + role, err := s.searchJSONForStringAttr(s.roleAttributePath, rawJSON) + if err == nil && role != "" { + return getRoleFromSearch(role) + } + + if groupBytes, err := json.Marshal(groupStruct{groups}); err == nil { + role, err := s.searchJSONForStringAttr(s.roleAttributePath, groupBytes) + if err == nil && role != "" { + return getRoleFromSearch(role) + } + } + + return "", false +} + +// defaultRole returns the default role for the user based on the autoAssignOrgRole setting +// if legacy is enabled "" is returned indicating the previous role assignment is used. +func (s *SocialBase) defaultRole() org.RoleType { + if s.autoAssignOrgRole != "" { + s.log.Debug("No role found, returning default.") + return org.RoleType(s.autoAssignOrgRole) + } + + // should never happen + return org.RoleViewer +} + +func (s *SocialBase) isGroupMember(groups []string) bool { + if len(s.allowedGroups) == 0 { + return true + } + + for _, allowedGroup := range s.allowedGroups { + for _, group := range groups { + if group == allowedGroup { + return true + } + } + } + + return false +} + +func (s *SocialBase) retrieveRawIDToken(idToken any) ([]byte, error) { + tokenString, ok := idToken.(string) + if !ok { + return nil, fmt.Errorf("id_token is not a string: %v", idToken) + } + + jwtRegexp := regexp.MustCompile("^([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)$") + matched := jwtRegexp.FindStringSubmatch(tokenString) + if matched == nil { + return nil, fmt.Errorf("id_token is not in JWT format: %s", tokenString) + } + + rawJSON, err := base64.RawURLEncoding.DecodeString(matched[2]) + if err != nil { + return nil, fmt.Errorf("error base64 decoding id_token: %w", err) + } + + headerBytes, err := base64.RawURLEncoding.DecodeString(matched[1]) + if err != nil { + return nil, fmt.Errorf("error base64 decoding header: %w", err) + } + + var header map[string]any + if err := json.Unmarshal(headerBytes, &header); err != nil { + return nil, fmt.Errorf("error deserializing header: %w", err) + } + + if compressionVal, exists := header["zip"]; exists { + compression, ok := compressionVal.(string) + if !ok { + return nil, fmt.Errorf("unrecognized compression header: %v", compressionVal) + } + + if compression != "DEF" { + return nil, fmt.Errorf("unknown compression algorithm: %s", compression) + } + + fr, err := zlib.NewReader(bytes.NewReader(rawJSON)) + if err != nil { + return nil, fmt.Errorf("error creating zlib reader: %w", err) + } + defer func() { + if err := fr.Close(); err != nil { + s.log.Warn("Failed closing zlib reader", "error", err) + } + }() + + rawJSON, err = io.ReadAll(fr) + if err != nil { + return nil, fmt.Errorf("error decompressing payload: %w", err) + } + } + + return rawJSON, nil +} + +// match grafana admin role and translate to org role and bool. +// treat the JSON search result to ensure correct casing. +func getRoleFromSearch(role string) (org.RoleType, bool) { + if strings.EqualFold(role, social.RoleGrafanaAdmin) { + return org.RoleAdmin, true + } + + return org.RoleType(cases.Title(language.Und).String(role)), false +} diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index a518e00df5a..9ffe00933a4 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -2,45 +2,54 @@ package social import ( "bytes" - "compress/zlib" "context" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" "fmt" - "io" - "net" "net/http" - "os" - "regexp" - "slices" - "strings" - "time" - "golang.org/x/oauth2" - "golang.org/x/text/cases" - "golang.org/x/text/language" - - "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/infra/remotecache" - "github.com/grafana/grafana/pkg/infra/usagestats" - "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/supportbundles" - "github.com/grafana/grafana/pkg/setting" + "golang.org/x/oauth2" ) const ( OfflineAccessScope = "offline_access" RoleGrafanaAdmin = "GrafanaAdmin" // For AzureAD for example this value cannot contain spaces + + AzureADProviderName = "azuread" + GenericOAuthProviderName = "generic_oauth" + GitHubProviderName = "github" + GitlabProviderName = "gitlab" + GoogleProviderName = "google" + GrafanaComProviderName = "grafana_com" + // legacy/old settings for the provider + GrafanaNetProviderName = "grafananet" + OktaProviderName = "okta" ) -type SocialService struct { - cfg *setting.Cfg +var ( + SocialBaseUrl = "/login/" +) - socialMap map[string]SocialConnector - log log.Logger +type Service interface { + GetOAuthProviders() map[string]bool + GetOAuthHttpClient(string) (*http.Client, error) + GetConnector(string) (SocialConnector, error) + GetOAuthInfoProvider(string) *OAuthInfo + GetOAuthInfoProviders() map[string]*OAuthInfo +} + +//go:generate mockery --name SocialConnector --structname MockSocialConnector --outpkg socialtest --filename social_connector_mock.go --output ./socialtest/ +type SocialConnector interface { + UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) + IsEmailAllowed(email string) bool + IsSignupAllowed() bool + + GetOAuthInfo() *OAuthInfo + + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) + Client(ctx context.Context, t *oauth2.Token) *http.Client + TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource + SupportBundleContent(*bytes.Buffer) error } type OAuthInfo struct { @@ -88,51 +97,6 @@ func NewOAuthInfo() *OAuthInfo { } } -func ProvideService(cfg *setting.Cfg, - features *featuremgmt.FeatureManager, - usageStats usagestats.Service, - bundleRegistry supportbundles.Service, - cache remotecache.CacheStorage, -) *SocialService { - ss := &SocialService{ - cfg: cfg, - socialMap: make(map[string]SocialConnector), - log: log.New("login.social"), - } - - usageStats.RegisterMetricsFunc(ss.getUsageStats) - - for _, name := range allOauthes { - sec := cfg.Raw.Section("auth." + name) - - settingsKVs := convertIniSectionToMap(sec) - info, err := CreateOAuthInfoFromKeyValues(settingsKVs) - if err != nil { - ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", name) - continue - } - - if !info.Enabled { - continue - } - - if name == GrafanaNetProviderName { - name = GrafanaComProviderName - } - - conn, err := ss.createOAuthConnector(name, settingsKVs, cfg, features, cache) - if err != nil { - ss.log.Error("Failed to create OAuth provider", "error", err, "provider", name) - } - - ss.socialMap[name] = conn - } - - ss.registerSupportBundleCollectors(bundleRegistry) - - return ss -} - type BasicUserInfo struct { Id string Name string @@ -147,394 +111,3 @@ func (b *BasicUserInfo) String() string { return fmt.Sprintf("Id: %s, Name: %s, Email: %s, Login: %s, Role: %s, Groups: %v", b.Id, b.Name, b.Email, b.Login, b.Role, b.Groups) } - -//go:generate mockery --name SocialConnector --structname MockSocialConnector --outpkg socialtest --filename social_connector_mock.go --output ../socialtest/ -type SocialConnector interface { - UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) - IsEmailAllowed(email string) bool - IsSignupAllowed() bool - - GetOAuthInfo() *OAuthInfo - - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) - Client(ctx context.Context, t *oauth2.Token) *http.Client - TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource - SupportBundleContent(*bytes.Buffer) error -} - -type SocialBase struct { - *oauth2.Config - info *OAuthInfo - log log.Logger - allowSignup bool - allowAssignGrafanaAdmin bool - allowedDomains []string - allowedGroups []string - - roleAttributePath string - roleAttributeStrict bool - autoAssignOrgRole string - skipOrgRoleSync bool - features featuremgmt.FeatureManager - useRefreshToken bool -} - -type Error struct { - s string -} - -func (e Error) Error() string { - return e.s -} - -var ( - SocialBaseUrl = "/login/" - SocialMap = make(map[string]SocialConnector) - allOauthes = []string{GitHubProviderName, GitlabProviderName, GoogleProviderName, GenericOAuthProviderName, GrafanaNetProviderName, - GrafanaComProviderName, AzureADProviderName, OktaProviderName} -) - -type Service interface { - GetOAuthProviders() map[string]bool - GetOAuthHttpClient(string) (*http.Client, error) - GetConnector(string) (SocialConnector, error) - GetOAuthInfoProvider(string) *OAuthInfo - GetOAuthInfoProviders() map[string]*OAuthInfo -} - -func newSocialBase(name string, - config *oauth2.Config, - info *OAuthInfo, - autoAssignOrgRole string, - skipOrgRoleSync bool, - features featuremgmt.FeatureManager, -) *SocialBase { - logger := log.New("oauth." + name) - - return &SocialBase{ - Config: config, - info: info, - log: logger, - allowSignup: info.AllowSignup, - allowAssignGrafanaAdmin: info.AllowAssignGrafanaAdmin, - allowedDomains: info.AllowedDomains, - allowedGroups: info.AllowedGroups, - roleAttributePath: info.RoleAttributePath, - roleAttributeStrict: info.RoleAttributeStrict, - autoAssignOrgRole: autoAssignOrgRole, - skipOrgRoleSync: skipOrgRoleSync, - features: features, - useRefreshToken: info.UseRefreshToken, - } -} - -type groupStruct struct { - Groups []string `json:"groups"` -} - -func (s *SocialBase) SupportBundleContent(bf *bytes.Buffer) error { - bf.WriteString("## Client configuration\n\n") - bf.WriteString("```ini\n") - bf.WriteString(fmt.Sprintf("allow_assign_grafana_admin = %v\n", s.allowAssignGrafanaAdmin)) - bf.WriteString(fmt.Sprintf("allow_sign_up = %v\n", s.allowSignup)) - bf.WriteString(fmt.Sprintf("allowed_domains = %v\n", s.allowedDomains)) - bf.WriteString(fmt.Sprintf("auto_assign_org_role = %v\n", s.autoAssignOrgRole)) - bf.WriteString(fmt.Sprintf("role_attribute_path = %v\n", s.roleAttributePath)) - bf.WriteString(fmt.Sprintf("role_attribute_strict = %v\n", s.roleAttributeStrict)) - bf.WriteString(fmt.Sprintf("skip_org_role_sync = %v\n", s.skipOrgRoleSync)) - bf.WriteString(fmt.Sprintf("client_id = %v\n", s.Config.ClientID)) - bf.WriteString(fmt.Sprintf("client_secret = %v ; issue if empty\n", strings.Repeat("*", len(s.Config.ClientSecret)))) - bf.WriteString(fmt.Sprintf("auth_url = %v\n", s.Config.Endpoint.AuthURL)) - bf.WriteString(fmt.Sprintf("token_url = %v\n", s.Config.Endpoint.TokenURL)) - bf.WriteString(fmt.Sprintf("auth_style = %v\n", s.Config.Endpoint.AuthStyle)) - bf.WriteString(fmt.Sprintf("redirect_url = %v\n", s.Config.RedirectURL)) - bf.WriteString(fmt.Sprintf("scopes = %v\n", s.Config.Scopes)) - bf.WriteString("```\n\n") - return nil -} - -func (s *SocialBase) extractRoleAndAdminOptional(rawJSON []byte, groups []string) (org.RoleType, bool, error) { - if s.roleAttributePath == "" { - if s.roleAttributeStrict { - return "", false, errRoleAttributePathNotSet.Errorf("role_attribute_path not set and role_attribute_strict is set") - } - return "", false, nil - } - - if role, gAdmin := s.searchRole(rawJSON, groups); role.IsValid() { - return role, gAdmin, nil - } else if role != "" { - return "", false, errInvalidRole.Errorf("invalid role: %s", role) - } - - if s.roleAttributeStrict { - return "", false, errRoleAttributeStrictViolation.Errorf("idP did not return a role attribute, but role_attribute_strict is set") - } - - return "", false, nil -} - -func (s *SocialBase) extractRoleAndAdmin(rawJSON []byte, groups []string) (org.RoleType, bool, error) { - role, gAdmin, err := s.extractRoleAndAdminOptional(rawJSON, groups) - if role == "" { - role = s.defaultRole() - } - - return role, gAdmin, err -} - -func (s *SocialBase) searchRole(rawJSON []byte, groups []string) (org.RoleType, bool) { - role, err := s.searchJSONForStringAttr(s.roleAttributePath, rawJSON) - if err == nil && role != "" { - return getRoleFromSearch(role) - } - - if groupBytes, err := json.Marshal(groupStruct{groups}); err == nil { - role, err := s.searchJSONForStringAttr(s.roleAttributePath, groupBytes) - if err == nil && role != "" { - return getRoleFromSearch(role) - } - } - - return "", false -} - -// defaultRole returns the default role for the user based on the autoAssignOrgRole setting -// if legacy is enabled "" is returned indicating the previous role assignment is used. -func (s *SocialBase) defaultRole() org.RoleType { - if s.autoAssignOrgRole != "" { - s.log.Debug("No role found, returning default.") - return org.RoleType(s.autoAssignOrgRole) - } - - // should never happen - return org.RoleViewer -} - -// match grafana admin role and translate to org role and bool. -// treat the JSON search result to ensure correct casing. -func getRoleFromSearch(role string) (org.RoleType, bool) { - if strings.EqualFold(role, RoleGrafanaAdmin) { - return org.RoleAdmin, true - } - - return org.RoleType(cases.Title(language.Und).String(role)), false -} - -// GetOAuthProviders returns available oauth providers and if they're enabled or not -func (ss *SocialService) GetOAuthProviders() map[string]bool { - result := map[string]bool{} - - for name, conn := range ss.socialMap { - result[name] = conn.GetOAuthInfo().Enabled - } - - return result -} - -func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) { - // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does - name = strings.TrimPrefix(name, "oauth_") - provider, ok := ss.socialMap[name] - if !ok { - return nil, fmt.Errorf("could not find %q in OAuth Settings", name) - } - - info := provider.GetOAuthInfo() - if !info.Enabled { - return nil, fmt.Errorf("oauth provider %q is not enabled", name) - } - - // handle call back - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: info.TlsSkipVerify, - }, - DialContext: (&net.Dialer{ - Timeout: time.Second * 10, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 15 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - } - - oauthClient := &http.Client{ - Transport: tr, - Timeout: time.Second * 15, - } - - if info.TlsClientCert != "" || info.TlsClientKey != "" { - cert, err := tls.LoadX509KeyPair(info.TlsClientCert, info.TlsClientKey) - if err != nil { - ss.log.Error("Failed to setup TlsClientCert", "oauth", name, "error", err) - return nil, fmt.Errorf("failed to setup TlsClientCert: %w", err) - } - - tr.TLSClientConfig.Certificates = append(tr.TLSClientConfig.Certificates, cert) - } - - if info.TlsClientCa != "" { - caCert, err := os.ReadFile(info.TlsClientCa) - if err != nil { - ss.log.Error("Failed to setup TlsClientCa", "oauth", name, "error", err) - return nil, fmt.Errorf("failed to setup TlsClientCa: %w", err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - tr.TLSClientConfig.RootCAs = caCertPool - } - return oauthClient, nil -} - -func (ss *SocialService) GetConnector(name string) (SocialConnector, error) { - // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does - provider := strings.TrimPrefix(name, "oauth_") - connector, ok := ss.socialMap[provider] - if !ok { - return nil, fmt.Errorf("failed to find oauth provider for %q", name) - } - return connector, nil -} - -func (ss *SocialService) GetOAuthInfoProvider(name string) *OAuthInfo { - connector, ok := ss.socialMap[name] - if !ok { - return nil - } - return connector.GetOAuthInfo() -} - -// GetOAuthInfoProviders returns enabled OAuth providers -func (ss *SocialService) GetOAuthInfoProviders() map[string]*OAuthInfo { - result := map[string]*OAuthInfo{} - for name, connector := range ss.socialMap { - info := connector.GetOAuthInfo() - if info.Enabled { - result[name] = info - } - } - return result -} - -func (ss *SocialService) getUsageStats(ctx context.Context) (map[string]any, error) { - m := map[string]any{} - - authTypes := map[string]bool{} - for provider, enabled := range ss.GetOAuthProviders() { - authTypes["oauth_"+provider] = enabled - } - - for authType, enabled := range authTypes { - enabledValue := 0 - if enabled { - enabledValue = 1 - } - - m["stats.auth_enabled."+authType+".count"] = enabledValue - } - - return m, nil -} - -func (s *SocialBase) isGroupMember(groups []string) bool { - if len(s.allowedGroups) == 0 { - return true - } - - for _, allowedGroup := range s.allowedGroups { - for _, group := range groups { - if group == allowedGroup { - return true - } - } - } - - return false -} - -func (s *SocialBase) retrieveRawIDToken(idToken any) ([]byte, error) { - tokenString, ok := idToken.(string) - if !ok { - return nil, fmt.Errorf("id_token is not a string: %v", idToken) - } - - jwtRegexp := regexp.MustCompile("^([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)$") - matched := jwtRegexp.FindStringSubmatch(tokenString) - if matched == nil { - return nil, fmt.Errorf("id_token is not in JWT format: %s", tokenString) - } - - rawJSON, err := base64.RawURLEncoding.DecodeString(matched[2]) - if err != nil { - return nil, fmt.Errorf("error base64 decoding id_token: %w", err) - } - - headerBytes, err := base64.RawURLEncoding.DecodeString(matched[1]) - if err != nil { - return nil, fmt.Errorf("error base64 decoding header: %w", err) - } - - var header map[string]any - if err := json.Unmarshal(headerBytes, &header); err != nil { - return nil, fmt.Errorf("error deserializing header: %w", err) - } - - if compressionVal, exists := header["zip"]; exists { - compression, ok := compressionVal.(string) - if !ok { - return nil, fmt.Errorf("unrecognized compression header: %v", compressionVal) - } - - if compression != "DEF" { - return nil, fmt.Errorf("unknown compression algorithm: %s", compression) - } - - fr, err := zlib.NewReader(bytes.NewReader(rawJSON)) - if err != nil { - return nil, fmt.Errorf("error creating zlib reader: %w", err) - } - defer func() { - if err := fr.Close(); err != nil { - s.log.Warn("Failed closing zlib reader", "error", err) - } - }() - - rawJSON, err = io.ReadAll(fr) - if err != nil { - return nil, fmt.Errorf("error decompressing payload: %w", err) - } - } - - return rawJSON, nil -} - -func (ss *SocialService) createOAuthConnector(name string, settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (SocialConnector, error) { - switch name { - case AzureADProviderName: - return NewAzureADProvider(settings, cfg, features, cache) - case GenericOAuthProviderName: - return NewGenericOAuthProvider(settings, cfg, features) - case GitHubProviderName: - return NewGitHubProvider(settings, cfg, features) - case GitlabProviderName: - return NewGitLabProvider(settings, cfg, features) - case GoogleProviderName: - return NewGoogleProvider(settings, cfg, features) - case GrafanaComProviderName: - return NewGrafanaComProvider(settings, cfg, features) - case OktaProviderName: - return NewOktaProvider(settings, cfg, features) - default: - return nil, fmt.Errorf("unknown oauth provider: %s", name) - } -} - -func appendUniqueScope(config *oauth2.Config, scope string) { - if !slices.Contains(config.Scopes, OfflineAccessScope) { - config.Scopes = append(config.Scopes, OfflineAccessScope) - } -} diff --git a/pkg/login/social/socialimpl/service.go b/pkg/login/social/socialimpl/service.go new file mode 100644 index 00000000000..312efe16e5b --- /dev/null +++ b/pkg/login/social/socialimpl/service.go @@ -0,0 +1,245 @@ +package socialimpl + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/http" + "os" + "strings" + "time" + + "gopkg.in/ini.v1" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/infra/usagestats" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/connectors" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/supportbundles" + "github.com/grafana/grafana/pkg/setting" +) + +var ( + allOauthes = []string{social.GitHubProviderName, social.GitlabProviderName, social.GoogleProviderName, social.GenericOAuthProviderName, social.GrafanaNetProviderName, + social.GrafanaComProviderName, social.AzureADProviderName, social.OktaProviderName} +) + +type SocialService struct { + cfg *setting.Cfg + + socialMap map[string]social.SocialConnector + log log.Logger +} + +func ProvideService(cfg *setting.Cfg, + features *featuremgmt.FeatureManager, + usageStats usagestats.Service, + bundleRegistry supportbundles.Service, + cache remotecache.CacheStorage, + ssoSettings ssosettings.Service, +) *SocialService { + ss := &SocialService{ + cfg: cfg, + socialMap: make(map[string]social.SocialConnector), + log: log.New("login.social"), + } + + usageStats.RegisterMetricsFunc(ss.getUsageStats) + + if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { + allSettings, err := ssoSettings.List(context.Background()) + if err != nil { + ss.log.Error("Failed to get SSO settings", "error", err) + } + + for _, ssoSetting := range allSettings { + conn, err := createOAuthConnector(ssoSetting.Provider, ssoSetting.OAuthSettings, cfg, ssoSettings, features, cache) + if err != nil { + ss.log.Error("Failed to create OAuth provider", "error", err, "provider", ssoSetting.Provider) + continue + } + + ss.socialMap[ssoSetting.Provider] = conn + } + } else { + for _, name := range allOauthes { + sec := cfg.Raw.Section("auth." + name) + + settingsKVs := convertIniSectionToMap(sec) + info, err := connectors.CreateOAuthInfoFromKeyValues(settingsKVs) + if err != nil { + ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", name) + continue + } + + if !info.Enabled { + continue + } + + if name == social.GrafanaNetProviderName { + name = social.GrafanaComProviderName + } + + conn, _ := createOAuthConnector(name, info, cfg, ssoSettings, features, cache) + + ss.socialMap[name] = conn + } + } + + ss.registerSupportBundleCollectors(bundleRegistry) + + return ss +} + +// GetOAuthProviders returns available oauth providers and if they're enabled or not +func (ss *SocialService) GetOAuthProviders() map[string]bool { + result := map[string]bool{} + + for name, conn := range ss.socialMap { + result[name] = conn.GetOAuthInfo().Enabled + } + + return result +} + +func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) { + // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does + name = strings.TrimPrefix(name, "oauth_") + provider, ok := ss.socialMap[name] + if !ok { + return nil, fmt.Errorf("could not find %q in OAuth Settings", name) + } + + info := provider.GetOAuthInfo() + if !info.Enabled { + return nil, fmt.Errorf("oauth provider %q is not enabled", name) + } + + // handle call back + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: info.TlsSkipVerify, + }, + DialContext: (&net.Dialer{ + Timeout: time.Second * 10, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 15 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + } + + oauthClient := &http.Client{ + Transport: tr, + Timeout: time.Second * 15, + } + + if info.TlsClientCert != "" || info.TlsClientKey != "" { + cert, err := tls.LoadX509KeyPair(info.TlsClientCert, info.TlsClientKey) + if err != nil { + ss.log.Error("Failed to setup TlsClientCert", "oauth", name, "error", err) + return nil, fmt.Errorf("failed to setup TlsClientCert: %w", err) + } + + tr.TLSClientConfig.Certificates = append(tr.TLSClientConfig.Certificates, cert) + } + + if info.TlsClientCa != "" { + caCert, err := os.ReadFile(info.TlsClientCa) + if err != nil { + ss.log.Error("Failed to setup TlsClientCa", "oauth", name, "error", err) + return nil, fmt.Errorf("failed to setup TlsClientCa: %w", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tr.TLSClientConfig.RootCAs = caCertPool + } + return oauthClient, nil +} + +func (ss *SocialService) GetConnector(name string) (social.SocialConnector, error) { + // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does + provider := strings.TrimPrefix(name, "oauth_") + connector, ok := ss.socialMap[provider] + if !ok { + return nil, fmt.Errorf("failed to find oauth provider for %q", name) + } + return connector, nil +} + +func (ss *SocialService) GetOAuthInfoProvider(name string) *social.OAuthInfo { + connector, ok := ss.socialMap[name] + if !ok { + return nil + } + return connector.GetOAuthInfo() +} + +// GetOAuthInfoProviders returns enabled OAuth providers +func (ss *SocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo { + result := map[string]*social.OAuthInfo{} + for name, connector := range ss.socialMap { + info := connector.GetOAuthInfo() + if info.Enabled { + result[name] = info + } + } + return result +} + +func (ss *SocialService) getUsageStats(ctx context.Context) (map[string]any, error) { + m := map[string]any{} + + authTypes := map[string]bool{} + for provider, enabled := range ss.GetOAuthProviders() { + authTypes["oauth_"+provider] = enabled + } + + for authType, enabled := range authTypes { + enabledValue := 0 + if enabled { + enabledValue = 1 + } + + m["stats.auth_enabled."+authType+".count"] = enabledValue + } + + return m, nil +} + +func createOAuthConnector(name string, info *social.OAuthInfo, cfg *setting.Cfg, ssoSettings ssosettings.Service, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (social.SocialConnector, error) { + switch name { + case social.AzureADProviderName: + return connectors.NewAzureADProvider(info, cfg, ssoSettings, features, cache), nil + case social.GenericOAuthProviderName: + return connectors.NewGenericOAuthProvider(info, cfg, ssoSettings, features), nil + case social.GitHubProviderName: + return connectors.NewGitHubProvider(info, cfg, ssoSettings, features), nil + case social.GitlabProviderName: + return connectors.NewGitLabProvider(info, cfg, ssoSettings, features), nil + case social.GoogleProviderName: + return connectors.NewGoogleProvider(info, cfg, ssoSettings, features), nil + case social.GrafanaComProviderName: + return connectors.NewGrafanaComProvider(info, cfg, ssoSettings, features), nil + case social.OktaProviderName: + return connectors.NewOktaProvider(info, cfg, ssoSettings, features), nil + default: + return nil, fmt.Errorf("unknown oauth provider: %s", name) + } +} + +// convertIniSectionToMap converts key value pairs from an ini section to a map[string]any +func convertIniSectionToMap(sec *ini.Section) map[string]any { + mappedSettings := make(map[string]any) + for k, v := range sec.KeysHash() { + mappedSettings[k] = v + } + return mappedSettings +} diff --git a/pkg/login/social/commont_test.go b/pkg/login/social/socialimpl/service_test.go similarity index 92% rename from pkg/login/social/commont_test.go rename to pkg/login/social/socialimpl/service_test.go index c3728fd872e..2fd4ad57ad6 100644 --- a/pkg/login/social/commont_test.go +++ b/pkg/login/social/socialimpl/service_test.go @@ -1,10 +1,13 @@ -package social +package socialimpl import ( "testing" - "github.com/stretchr/testify/require" "gopkg.in/ini.v1" + + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/connectors" + "github.com/stretchr/testify/require" ) func TestMapping_IniSectionOAuthInfo(t *testing.T) { @@ -53,7 +56,7 @@ signout_redirect_url = https://oauth.com/signout?post_logout_redirect_uri=https: iniFile, err := ini.Load([]byte(iniContent)) require.NoError(t, err) - expectedOAuthInfo := &OAuthInfo{ + expectedOAuthInfo := &social.OAuthInfo{ Name: "OAuth", Icon: "signin", Enabled: true, @@ -96,7 +99,7 @@ signout_redirect_url = https://oauth.com/signout?post_logout_redirect_uri=https: } settingsKVs := convertIniSectionToMap(iniFile.Section("test")) - oauthInfo, err := CreateOAuthInfoFromKeyValues(settingsKVs) + oauthInfo, err := connectors.CreateOAuthInfoFromKeyValues(settingsKVs) require.NoError(t, err) require.Equal(t, expectedOAuthInfo, oauthInfo) diff --git a/pkg/login/social/support_bundle.go b/pkg/login/social/socialimpl/support_bundle.go similarity index 92% rename from pkg/login/social/support_bundle.go rename to pkg/login/social/socialimpl/support_bundle.go index 54dafcb8b26..09866a30b3e 100644 --- a/pkg/login/social/support_bundle.go +++ b/pkg/login/social/socialimpl/support_bundle.go @@ -1,4 +1,4 @@ -package social +package socialimpl import ( "bytes" @@ -9,6 +9,7 @@ import ( "github.com/BurntSushi/toml" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/supportbundles" ) @@ -26,7 +27,7 @@ func (ss *SocialService) registerSupportBundleCollectors(bundleRegistry supportb } } -func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnector) func(context.Context) (*supportbundles.SupportItem, error) { +func (ss *SocialService) supportBundleCollectorFn(name string, sc social.SocialConnector) func(context.Context) (*supportbundles.SupportItem, error) { return func(ctx context.Context) (*supportbundles.SupportItem, error) { bWriter := bytes.NewBuffer(nil) @@ -61,7 +62,7 @@ func (ss *SocialService) supportBundleCollectorFn(name string, sc SocialConnecto } } -func (ss *SocialService) healthCheckSocialConnector(ctx context.Context, name string, oinfo *OAuthInfo, bWriter *bytes.Buffer) { +func (ss *SocialService) healthCheckSocialConnector(ctx context.Context, name string, oinfo *social.OAuthInfo, bWriter *bytes.Buffer) { bWriter.WriteString("## Health checks\n\n") client, err := ss.GetOAuthHttpClient(name) if err != nil { diff --git a/pkg/login/socialtest/social_connector_mock.go b/pkg/login/social/socialtest/social_connector_mock.go similarity index 97% rename from pkg/login/socialtest/social_connector_mock.go rename to pkg/login/social/socialtest/social_connector_mock.go index abac2e322a6..17ed902bd01 100644 --- a/pkg/login/socialtest/social_connector_mock.go +++ b/pkg/login/social/socialtest/social_connector_mock.go @@ -22,11 +22,11 @@ type MockSocialConnector struct { // AuthCodeURL provides a mock function with given fields: state, opts func (_m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - _va := make([]any, len(opts)) + _va := make([]interface{}, len(opts)) for _i := range opts { _va[_i] = opts[_i] } - var _ca []any + var _ca []interface{} _ca = append(_ca, state) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -59,11 +59,11 @@ func (_m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *htt // Exchange provides a mock function with given fields: ctx, code, authOptions func (_m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - _va := make([]any, len(authOptions)) + _va := make([]interface{}, len(authOptions)) for _i := range authOptions { _va[_i] = authOptions[_i] } - var _ca []any + var _ca []interface{} _ca = append(_ca, ctx, code) _ca = append(_ca, _va...) ret := _m.Called(_ca...) diff --git a/pkg/login/socialtest/social_service_fake.go b/pkg/login/social/socialtest/social_service_fake.go similarity index 100% rename from pkg/login/socialtest/social_service_fake.go rename to pkg/login/social/socialtest/social_service_fake.go diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 23914a8b3e7..64909565597 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -31,6 +31,7 @@ import ( "github.com/grafana/grafana/pkg/infra/usagestats/statscollector" "github.com/grafana/grafana/pkg/infra/usagestats/validator" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/socialimpl" "github.com/grafana/grafana/pkg/middleware/csrf" "github.com/grafana/grafana/pkg/middleware/loggermw" apiregistry "github.com/grafana/grafana/pkg/registry/apis" @@ -256,9 +257,9 @@ var wireBasicSet = wire.NewSet( testdatasource.ProvideService, ldapapi.ProvideService, opentsdb.ProvideService, - social.ProvideService, + socialimpl.ProvideService, influxdb.ProvideService, - wire.Bind(new(social.Service), new(*social.SocialService)), + wire.Bind(new(social.Service), new(*socialimpl.SocialService)), tempo.ProvideService, loki.ProvideService, graphite.ProvideService, diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go index 8c494c8525f..3632deb5633 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go @@ -16,7 +16,7 @@ import ( "github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/login/socialtest" + "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/auth/identity" diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index e2b39229478..55fcca99c63 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -15,6 +15,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/connectors" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/org" @@ -49,7 +50,7 @@ var ( errOAuthEmailNotAllowed = errutil.Unauthorized("auth.oauth.email.not-allowed", errutil.WithPublicMessage("Required email domain not fulfilled")) ) -func fromSocialErr(err *social.Error) error { +func fromSocialErr(err *connectors.SocialError) error { return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err) } @@ -118,7 +119,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token) if err != nil { - var sErr *social.Error + var sErr *connectors.SocialError if errors.As(err, &sErr) { return nil, fromSocialErr(sErr) } diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index 0103748a286..28e56398ac6 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -13,7 +13,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/sync/singleflight" - "github.com/grafana/grafana/pkg/login/socialtest" + "github.com/grafana/grafana/pkg/login/social/socialtest" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login/authinfoimpl" "github.com/grafana/grafana/pkg/services/user" diff --git a/pkg/services/ssosettings/api/api.go b/pkg/services/ssosettings/api/api.go index e6840ed471a..b92eadd1230 100644 --- a/pkg/services/ssosettings/api/api.go +++ b/pkg/services/ssosettings/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "net/http" @@ -8,6 +9,7 @@ import ( "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/infra/log" ac "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/auth/identity" contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/ssosettings" @@ -56,7 +58,7 @@ func (api *Api) RegisterAPIEndpoints() { } func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Response { - providers, err := api.SSOSettingsService.List(c.Req.Context(), c.SignedInUser) + providers, err := api.getAuthorizedList(c.Req.Context(), c.SignedInUser) if err != nil { return response.Error(500, "Failed to get providers", err) } @@ -75,6 +77,31 @@ func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Re return response.JSON(http.StatusOK, dtos) } +func (api *Api) getAuthorizedList(ctx context.Context, identity identity.Requester) ([]*models.SSOSettings, error) { + allProviders, err := api.SSOSettingsService.List(ctx) + if err != nil { + return nil, err + } + + var authorizedProviders []*models.SSOSettings + for _, provider := range allProviders { + ev := ac.EvalPermission(ac.ActionSettingsRead, ac.Scope("settings", "auth."+provider.Provider, "*")) + hasAccess, err := api.AccessControl.Evaluate(ctx, identity, ev) + if err != nil { + api.Log.FromContext(ctx).Error("Failed to evaluate permissions", "error", err) + return nil, err + } + + if !hasAccess { + continue + } + + authorizedProviders = append(authorizedProviders, provider) + } + + return authorizedProviders, nil +} + func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Response { key, ok := web.Params(c.Req)[":key"] if !ok { diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go index 5899353525b..023d1f27939 100644 --- a/pkg/services/ssosettings/ssosettings.go +++ b/pkg/services/ssosettings/ssosettings.go @@ -4,7 +4,6 @@ import ( "context" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/ssosettings/models" ) @@ -21,7 +20,7 @@ var ( //go:generate mockery --name Service --structname MockService --outpkg ssosettingstests --filename service_mock.go --output ./ssosettingstests/ type Service interface { // List returns all SSO settings from DB and config files - List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) + List(ctx context.Context) ([]*models.SSOSettings, error) // GetForProvider returns the SSO settings for a given provider (DB or config file) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) // Upsert creates or updates the SSO settings for a given provider @@ -30,15 +29,16 @@ type Service interface { Delete(ctx context.Context, provider string) error // Patch updates the specified SSO settings (key-value pairs) for a given provider Patch(ctx context.Context, provider string, data map[string]any) error - // RegisterReloadable registers a reloadable provider - RegisterReloadable(ctx context.Context, provider string, reloadable Reloadable) - // Reload implements ssosettings.Reloadable interface + // RegisterReloadable registers a reloadable for a given provider + RegisterReloadable(provider string, reloadable Reloadable) + // Reload reloads the settings for a given provider Reload(ctx context.Context, provider string) } -// Reloadable is an interface that can be implemented by a provider to allow it to be reloaded +// Reloadable is an interface that can be implemented by a provider to allow it to be validated and reloaded type Reloadable interface { - Reload(ctx context.Context) error + Reload(ctx context.Context, settings models.SSOSettings) error + Validate(ctx context.Context, settings models.SSOSettings) error } // FallbackStrategy is an interface that can be implemented to allow a provider to load settings from a different source diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go index 0dc92704efc..c51f0b7b13f 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -9,7 +9,6 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" ac "github.com/grafana/grafana/pkg/services/accesscontrol" - "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/ssosettings" @@ -23,12 +22,14 @@ import ( var _ ssosettings.Service = (*SSOSettingsService)(nil) type SSOSettingsService struct { - log log.Logger - cfg *setting.Cfg - store ssosettings.Store - ac ac.AccessControl + log log.Logger + cfg *setting.Cfg + store ssosettings.Store + ac ac.AccessControl + secrets secrets.Service + fbStrategies []ssosettings.FallbackStrategy - secrets secrets.Service + reloadables map[string]ssosettings.Reloadable } func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl, @@ -48,6 +49,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl, ac: ac, fbStrategies: strategies, secrets: secrets, + reloadables: make(map[string]ssosettings.Reloadable), } if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) { @@ -81,7 +83,7 @@ func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string return storeSettings, nil } -func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) { +func (s *SSOSettingsService) List(ctx context.Context) ([]*models.SSOSettings, error) { result := make([]*models.SSOSettings, 0, len(ssosettings.AllOAuthProviders)) storedSettings, err := s.store.List(ctx) @@ -90,25 +92,15 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques } for _, provider := range ssosettings.AllOAuthProviders { - ev := ac.EvalPermission(ac.ActionSettingsRead, ac.Scope("settings", "auth."+provider, "*")) - hasAccess, err := s.ac.Evaluate(ctx, requester, ev) - if err != nil { - return nil, err - } - - if !hasAccess { - continue - } - settings := getSettingsByProvider(provider, storedSettings) if len(settings) == 0 { // If there is no data in the DB then we need to load the settings using the fallback strategy - fallbackSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) if err != nil { return nil, err } - settings = append(settings, fallbackSettings) + settings = append(settings, setting) } result = append(result, settings...) } @@ -117,7 +109,8 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques } func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error { - // TODO: validation (configurable provider? Contains the required fields? etc) + // TODO: also check whether the provider is configurable + // Get the connector for the provider (from the reloadables) and call Validate if isOAuthProvider(settings.Provider) { encryptedClientSecret, err := s.secrets.Encrypt(ctx, []byte(settings.OAuthSettings.ClientSecret), secrets.WithoutScope()) @@ -131,6 +124,7 @@ func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSett if err != nil { return err } + return nil } @@ -146,8 +140,11 @@ func (s *SSOSettingsService) Reload(ctx context.Context, provider string) { panic("not implemented") // TODO: Implement } -func (s *SSOSettingsService) RegisterReloadable(ctx context.Context, provider string, reloadable ssosettings.Reloadable) { - panic("not implemented") // TODO: Implement +func (s *SSOSettingsService) RegisterReloadable(provider string, reloadable ssosettings.Reloadable) { + if s.reloadables == nil { + s.reloadables = make(map[string]ssosettings.Reloadable) + } + s.reloadables[provider] = reloadable } func (s *SSOSettingsService) RegisterFallbackStrategy(providerRegex string, strategy ssosettings.FallbackStrategy) { diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go index d75191729ba..d909bfb6658 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service_test.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -13,12 +13,10 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" - "github.com/grafana/grafana/pkg/services/auth/identity" secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" - "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -106,34 +104,11 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) { } func TestSSOSettingsService_List(t *testing.T) { - defaultIdentity := &user.SignedInUser{ - UserID: 1, - OrgID: 1, - Permissions: map[int64]map[string][]string{ - 1: { - accesscontrol.ActionSettingsRead: {accesscontrol.ScopeSettingsAll}, - }, - }, - } - - scopedIdentity := &user.SignedInUser{ - UserID: 1, - OrgID: 1, - Permissions: map[int64]map[string][]string{ - 1: { - accesscontrol.ActionSettingsRead: []string{ - accesscontrol.Scope("settings", "auth.azuread", "*"), - accesscontrol.Scope("settings", "auth.github", "*"), - }, - }, - }, - } testCases := []struct { - name string - setup func(env testEnv) - identity identity.Requester - want []*models.SSOSettings - wantErr bool + name string + setup func(env testEnv) + want []*models.SSOSettings + wantErr bool }{ { name: "should return successfully", @@ -153,7 +128,6 @@ func TestSSOSettingsService_List(t *testing.T) { env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: false} }, - identity: defaultIdentity, want: []*models.SSOSettings{ { Provider: "github", @@ -194,44 +168,10 @@ func TestSSOSettingsService_List(t *testing.T) { wantErr: false, }, { - name: "should return the settings that the user has access to", - setup: func(env testEnv) { - env.store.ExpectedSSOSettings = []*models.SSOSettings{ - { - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, - }, - { - Provider: "okta", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, - }, - } - env.fallbackStrategy.ExpectedIsMatch = true - env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: false} - }, - identity: scopedIdentity, - want: []*models.SSOSettings{ - { - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, - }, - { - Provider: "azuread", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, - }, - }, - wantErr: false, - }, - { - name: "should return error if store returns an error", - setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, - identity: defaultIdentity, - want: nil, - wantErr: true, + name: "should return error if store returns an error", + setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, + want: nil, + wantErr: true, }, { name: "should use the fallback strategy if store returns empty list", @@ -240,7 +180,6 @@ func TestSSOSettingsService_List(t *testing.T) { env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: false} }, - identity: defaultIdentity, want: []*models.SSOSettings{ { Provider: "github", @@ -286,9 +225,8 @@ func TestSSOSettingsService_List(t *testing.T) { env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = false }, - identity: defaultIdentity, - want: nil, - wantErr: true, + want: nil, + wantErr: true, }, } for _, tc := range testCases { @@ -298,7 +236,7 @@ func TestSSOSettingsService_List(t *testing.T) { tc.setup(env) } - actual, err := env.service.List(context.Background(), tc.identity) + actual, err := env.service.List(context.Background()) if tc.wantErr { require.Error(t, err) @@ -416,6 +354,7 @@ func setupTestEnv(t *testing.T) testEnv { store: store, ac: accessControl, fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy}, + reloadables: make(map[string]ssosettings.Reloadable), secrets: secrets, } diff --git a/pkg/services/ssosettings/ssosettingstests/service_mock.go b/pkg/services/ssosettings/ssosettingstests/service_mock.go index 6cb3edbf90a..31efe79018e 100644 --- a/pkg/services/ssosettings/ssosettingstests/service_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/service_mock.go @@ -1,14 +1,12 @@ -// Code generated by mockery v2.37.1. DO NOT EDIT. +// Code generated by mockery v2.27.1. DO NOT EDIT. package ssosettingstests import ( context "context" - identity "github.com/grafana/grafana/pkg/services/auth/identity" - mock "github.com/stretchr/testify/mock" - models "github.com/grafana/grafana/pkg/services/ssosettings/models" + mock "github.com/stretchr/testify/mock" ssosettings "github.com/grafana/grafana/pkg/services/ssosettings" ) @@ -58,25 +56,25 @@ func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*mo return r0, r1 } -// List provides a mock function with given fields: ctx, requester -func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) { - ret := _m.Called(ctx, requester) +// List provides a mock function with given fields: ctx +func (_m *MockService) List(ctx context.Context) ([]*models.SSOSettings, error) { + ret := _m.Called(ctx) var r0 []*models.SSOSettings var r1 error - if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSettings, error)); ok { - return rf(ctx, requester) + if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSettings, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSettings); ok { - r0 = rf(ctx, requester) + if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSettings); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SSOSettings) } } - if rf, ok := ret.Get(1).(func(context.Context, identity.Requester) error); ok { - r1 = rf(ctx, requester) + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -98,9 +96,9 @@ func (_m *MockService) Patch(ctx context.Context, provider string, data map[stri return r0 } -// RegisterReloadable provides a mock function with given fields: ctx, provider, reloadable -func (_m *MockService) RegisterReloadable(ctx context.Context, provider string, reloadable ssosettings.Reloadable) { - _m.Called(ctx, provider, reloadable) +// RegisterReloadable provides a mock function with given fields: provider, reloadable +func (_m *MockService) RegisterReloadable(provider string, reloadable ssosettings.Reloadable) { + _m.Called(provider, reloadable) } // Reload provides a mock function with given fields: ctx, provider @@ -122,12 +120,13 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) return r0 } -// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockService(t interface { +type mockConstructorTestingTNewMockService interface { mock.TestingT Cleanup(func()) -}) *MockService { +} + +// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockService(t mockConstructorTestingTNewMockService) *MockService { mock := &MockService{} mock.Mock.Test(t) diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go index 55214c18fe8..17d3adbe12b 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.37.1. DO NOT EDIT. +// Code generated by mockery v2.27.1. DO NOT EDIT. package ssosettingstests @@ -108,12 +108,13 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er return r0 } -// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockStore(t interface { +type mockConstructorTestingTNewMockStore interface { mock.TestingT Cleanup(func()) -}) *MockStore { +} + +// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore { mock := &MockStore{} mock.Mock.Test(t) diff --git a/pkg/services/ssosettings/strategies/oauth_strategy.go b/pkg/services/ssosettings/strategies/oauth_strategy.go index ae14b1f32eb..3f43b6625dc 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy.go @@ -4,6 +4,7 @@ import ( "context" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/login/social/connectors" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" @@ -15,11 +16,11 @@ type OAuthStrategy struct { } var extraKeysByProvider = map[string][]string{ - social.AzureADProviderName: social.ExtraAzureADSettingKeys, - social.GenericOAuthProviderName: social.ExtraGenericOAuthSettingKeys, - social.GitHubProviderName: social.ExtraGithubSettingKeys, - social.GrafanaComProviderName: social.ExtraGrafanaComSettingKeys, - social.GrafanaNetProviderName: social.ExtraGrafanaComSettingKeys, + social.AzureADProviderName: connectors.ExtraAzureADSettingKeys, + social.GenericOAuthProviderName: connectors.ExtraGenericOAuthSettingKeys, + social.GitHubProviderName: connectors.ExtraGithubSettingKeys, + social.GrafanaComProviderName: connectors.ExtraGrafanaComSettingKeys, + social.GrafanaNetProviderName: connectors.ExtraGrafanaComSettingKeys, } var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil)