diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 9ffe00933a4..0b0ef84c2e6 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -53,39 +53,39 @@ type SocialConnector interface { } type OAuthInfo struct { - AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin" json:"allowAssignGrafanaAdmin"` - AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up" json:"allowSignup"` - AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains" json:"allowedDomains"` - AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups" json:"allowedGroups"` - ApiUrl string `mapstructure:"api_url" toml:"api_url" json:"apiUrl"` - AuthStyle string `mapstructure:"auth_style" toml:"auth_style" json:"authStyle"` - AuthUrl string `mapstructure:"auth_url" toml:"auth_url" json:"authUrl"` - AutoLogin bool `mapstructure:"auto_login" toml:"auto_login" json:"autoLogin"` - ClientId string `mapstructure:"client_id" toml:"client_id" json:"clientId"` - ClientSecret string `mapstructure:"client_secret" toml:"-" json:"clientSecret"` - EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name" json:"emailAttributeName"` - EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path" json:"emailAttributePath"` - EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes" json:"emptyScopes"` - Enabled bool `mapstructure:"enabled" toml:"enabled" json:"enabled"` - GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path" json:"groupsAttributePath"` - HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain" json:"hostedDomain"` - Icon string `mapstructure:"icon" toml:"icon" json:"icon"` - Name string `mapstructure:"name" toml:"name" json:"name"` - RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path" json:"roleAttributePath"` - RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict" json:"roleAttributeStrict"` - Scopes []string `mapstructure:"scopes" toml:"scopes" json:"scopes"` - SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url" json:"signoutRedirectUrl"` - SkipOrgRoleSync bool `mapstructure:"skip_org_role_sync" toml:"skip_org_role_sync" json:"skipOrgRoleSync"` - TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path" json:"teamIdsAttributePath"` - TeamsUrl string `mapstructure:"teams_url" toml:"teams_url" json:"teamsUrl"` - TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca" json:"tlsClientCa"` - TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert" json:"tlsClientCert"` - TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key" json:"tlsClientKey"` - TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure" json:"tlsSkipVerify"` - TokenUrl string `mapstructure:"token_url" toml:"token_url" json:"tokenUrl"` - UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce" json:"usePKCE"` - UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token" json:"useRefreshToken"` - Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty" json:"extra"` + AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin"` + AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up"` + AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains"` + AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups"` + ApiUrl string `mapstructure:"api_url" toml:"api_url"` + AuthStyle string `mapstructure:"auth_style" toml:"auth_style"` + AuthUrl string `mapstructure:"auth_url" toml:"auth_url"` + AutoLogin bool `mapstructure:"auto_login" toml:"auto_login"` + ClientId string `mapstructure:"client_id" toml:"client_id"` + ClientSecret string `mapstructure:"client_secret" toml:"-"` + EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name"` + EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path"` + EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes"` + Enabled bool `mapstructure:"enabled" toml:"enabled"` + GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path"` + HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain"` + Icon string `mapstructure:"icon" toml:"icon"` + Name string `mapstructure:"name" toml:"name"` + RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path"` + RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict"` + Scopes []string `mapstructure:"scopes" toml:"scopes"` + SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url"` + SkipOrgRoleSync bool `mapstructure:"skip_org_role_sync" toml:"skip_org_role_sync"` + TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path"` + TeamsUrl string `mapstructure:"teams_url" toml:"teams_url"` + TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca"` + TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert"` + TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key"` + TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure"` + TokenUrl string `mapstructure:"token_url" toml:"token_url"` + UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce"` + UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token"` + Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty"` } func NewOAuthInfo() *OAuthInfo { diff --git a/pkg/login/social/socialimpl/service.go b/pkg/login/social/socialimpl/service.go index 5cf45747749..30893daca07 100644 --- a/pkg/login/social/socialimpl/service.go +++ b/pkg/login/social/socialimpl/service.go @@ -58,7 +58,13 @@ func ProvideService(cfg *setting.Cfg, } for _, ssoSetting := range allSettings { - conn, err := createOAuthConnector(ssoSetting.Provider, ssoSetting.OAuthSettings, cfg, ssoSettings, features, cache) + info, err := connectors.CreateOAuthInfoFromKeyValues(ssoSetting.Settings) + if err != nil { + ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", ssoSetting.Provider) + continue + } + + conn, err := createOAuthConnector(ssoSetting.Provider, info, cfg, ssoSettings, features, cache) if err != nil { ss.log.Error("Failed to create OAuth provider", "error", err, "provider", ssoSetting.Provider) continue diff --git a/pkg/services/ssosettings/api/api.go b/pkg/services/ssosettings/api/api.go index b92eadd1230..9abbee0dc45 100644 --- a/pkg/services/ssosettings/api/api.go +++ b/pkg/services/ssosettings/api/api.go @@ -60,21 +60,10 @@ func (api *Api) RegisterAPIEndpoints() { func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Response { providers, err := api.getAuthorizedList(c.Req.Context(), c.SignedInUser) if err != nil { - return response.Error(500, "Failed to get providers", err) + return response.Error(http.StatusInternalServerError, "Failed to get providers", err) } - dtos := make([]*models.SSOSettingsDTO, 0) - for _, provider := range providers { - dto, err := provider.ToSSOSettingsDTO() - if err != nil { - api.Log.Warn("Failed to convert SSO Settings for provider " + provider.Provider) - continue - } - - dtos = append(dtos, dto) - } - - return response.JSON(http.StatusOK, dtos) + return response.JSON(http.StatusOK, providers) } func (api *Api) getAuthorizedList(ctx context.Context, identity identity.Requester) ([]*models.SSOSettings, error) { @@ -113,12 +102,7 @@ func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Respons return response.Error(http.StatusNotFound, "The provider was not found", err) } - dto, err := settings.ToSSOSettingsDTO() - if err != nil { - return response.Error(http.StatusInternalServerError, "The provider is invalid", err) - } - - return response.JSON(http.StatusOK, dto) + return response.JSON(http.StatusOK, settings) } func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Response { @@ -127,19 +111,14 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp return response.Error(http.StatusBadRequest, "Missing key", nil) } - var settingsDTO models.SSOSettingsDTO - if err := web.Bind(c.Req, &settingsDTO); err != nil { + var settings models.SSOSettings + if err := web.Bind(c.Req, &settings); err != nil { return response.Error(http.StatusBadRequest, "Failed to parse request body", err) } - settings, err := settingsDTO.ToSSOSettings() - if err != nil { - return response.Error(http.StatusBadRequest, "Invalid request body", err) - } - settings.Provider = key - err = api.SSOSettingsService.Upsert(c.Req.Context(), *settings) + err := api.SSOSettingsService.Upsert(c.Req.Context(), settings) // TODO: first check whether the error is referring to validation errors // other error diff --git a/pkg/services/ssosettings/database/database.go b/pkg/services/ssosettings/database/database.go index 4d1da5ed6a4..081b607dde7 100644 --- a/pkg/services/ssosettings/database/database.go +++ b/pkg/services/ssosettings/database/database.go @@ -32,7 +32,7 @@ func ProvideStore(sqlStore db.DB) *SSOSettingsStore { var _ ssosettings.Store = (*SSOSettingsStore)(nil) func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) { - result := models.SSOSettingsDTO{Provider: provider} + result := models.SSOSettings{Provider: provider} err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { var err error sess.Table("sso_setting") @@ -53,19 +53,14 @@ func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SS return nil, err } - dto, err := result.ToSSOSettings() - if err != nil { - return nil, err - } - - return dto, nil + return &result, nil } func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, error) { - dtos := make([]*models.SSOSettingsDTO, 0) + result := make([]*models.SSOSettings, 0) err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { sess.Table("sso_setting") - err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&dtos) + err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result) if err != nil { return err @@ -78,29 +73,13 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err return nil, err } - settings := make([]*models.SSOSettings, 0) - for _, dto := range dtos { - item, err := dto.ToSSOSettings() - if err != nil { - s.log.Warn("Failed to convert DB settings to SSOSettings for provider " + dto.Provider) - continue - } - - settings = append(settings, item) - } - - return settings, nil + return result, nil } func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error { - dto, err := settings.ToSSOSettingsDTO() - if err != nil { - return err - } - return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - existing := &models.SSOSettingsDTO{ - Provider: dto.Provider, + existing := &models.SSOSettings{ + Provider: settings.Provider, IsDeleted: false, } found, err := sess.UseBool("is_deleted").Exist(existing) @@ -111,17 +90,17 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin now := timeNow().UTC() if found { - updated := &models.SSOSettingsDTO{ - Settings: dto.Settings, + updated := &models.SSOSettings{ + Settings: settings.Settings, Updated: now, IsDeleted: false, } _, err = sess.UseBool("is_deleted").Update(updated, existing) } else { - _, err = sess.Insert(&models.SSOSettingsDTO{ + _, err = sess.Insert(&models.SSOSettings{ ID: uuid.New().String(), - Provider: dto.Provider, - Settings: dto.Settings, + Provider: settings.Provider, + Settings: settings.Settings, Created: now, Updated: now, }) @@ -137,7 +116,7 @@ func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[ func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error { return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - existing := &models.SSOSettingsDTO{ + existing := &models.SSOSettings{ Provider: provider, IsDeleted: false, } diff --git a/pkg/services/ssosettings/database/database_test.go b/pkg/services/ssosettings/database/database_test.go index ee8902889a7..416db88892d 100644 --- a/pkg/services/ssosettings/database/database_test.go +++ b/pkg/services/ssosettings/database/database_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/db" - "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings/models" @@ -28,9 +27,7 @@ func TestIntegrationGetSSOSettings(t *testing.T) { ssoSettingsStore = ProvideStore(sqlStore) template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - }, + Settings: map[string]any{"enabled": true}, } err := populateSSOSettings(sqlStore, template, "azuread") require.NoError(t, err) @@ -40,14 +37,14 @@ func TestIntegrationGetSSOSettings(t *testing.T) { setup() expected := &models.SSOSettings{ - Provider: "azuread", - OAuthSettings: &social.OAuthInfo{Enabled: true}, + Provider: "azuread", + Settings: map[string]any{"enabled": true}, } actual, err := ssoSettingsStore.Get(context.Background(), "azuread") require.NoError(t, err) - require.Equal(t, expected.OAuthSettings, actual.OAuthSettings) + require.EqualValues(t, expected.Settings, actual.Settings) }) t.Run("returns not found if the SSO setting is missing for the specified provider", func(t *testing.T) { @@ -88,9 +85,9 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { settings := models.SSOSettings{ Provider: "azuread", - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "azuread-client", + Settings: map[string]any{ + "enabled": true, + "client_id": "azuread-client", }, } @@ -99,7 +96,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false) require.NoError(t, err) - require.EqualValues(t, settings.OAuthSettings, actual.OAuthSettings) + require.EqualValues(t, settings.Settings, actual.Settings) require.NotEmpty(t, actual.ID) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created)) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) @@ -118,10 +115,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { provider := "github" template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "github-client", - ClientSecret: "this-is-a-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "github-client", + "client_secret": "this-is-a-secret", }, } err := populateSSOSettings(sqlStore, template, provider) @@ -129,10 +126,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { newSettings := models.SSOSettings{ Provider: provider, - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "new-github-client", - ClientSecret: "this-is-a-new-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "new-github-client", + "client_secret": "this-is-a-new-secret", }, } err = ssoSettingsStore.Upsert(context.Background(), newSettings) @@ -140,7 +137,7 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { actual, err := getSSOSettingsByProvider(sqlStore, provider, false) require.NoError(t, err) - require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) + require.EqualValues(t, newSettings.Settings, actual.Settings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) deleted, notDeleted, err := getSSOSettingsCountByDeleted(sqlStore) @@ -157,10 +154,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { provider := "azuread" template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "azuread-client", - ClientSecret: "this-is-a-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "azuread-client", + "client_secret": "this-is-a-secret", }, IsDeleted: true, } @@ -169,10 +166,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { newSettings := models.SSOSettings{ Provider: provider, - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "new-azuread-client", - ClientSecret: "this-is-a-new-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "new-azuread-client", + "client_secret": "this-is-a-new-secret", }, } @@ -181,13 +178,13 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { actual, err := getSSOSettingsByProvider(sqlStore, provider, false) require.NoError(t, err) - require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) + require.EqualValues(t, newSettings.Settings, actual.Settings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created)) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) old, err := getSSOSettingsByProvider(sqlStore, provider, true) require.NoError(t, err) - require.Equal(t, template.OAuthSettings, old.OAuthSettings) + require.EqualValues(t, template.Settings, old.Settings) }) t.Run("replaces the settings only for the specified provider leaving the other provider's settings unchanged", func(t *testing.T) { @@ -198,10 +195,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { providers := []string{"github", "gitlab", "google"} template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "my-client", - ClientSecret: "this-is-a-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "my-client", + "client_secret": "this-is-a-secret", }, } err := populateSSOSettings(sqlStore, template, providers...) @@ -209,10 +206,10 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { newSettings := models.SSOSettings{ Provider: providers[0], - OAuthSettings: &social.OAuthInfo{ - Enabled: true, - ClientId: "my-new-client", - ClientSecret: "this-is-my-new-secret", + Settings: map[string]any{ + "enabled": true, + "client_id": "my-new-client", + "client_secret": "this-is-my-new-secret", }, } err = ssoSettingsStore.Upsert(context.Background(), newSettings) @@ -220,13 +217,13 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false) require.NoError(t, err) - require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) + require.EqualValues(t, newSettings.Settings, actual.Settings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) for index := 1; index < len(providers); index++ { existing, err := getSSOSettingsByProvider(sqlStore, providers[index], false) require.NoError(t, err) - require.EqualValues(t, template.OAuthSettings, existing.OAuthSettings) + require.EqualValues(t, template.Settings, existing.Settings) } }) } @@ -244,16 +241,16 @@ func TestIntegrationListSSOSettings(t *testing.T) { ssoSettingsStore = ProvideStore(sqlStore) template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, + Settings: map[string]any{ + "enabled": true, }, } err := populateSSOSettings(sqlStore, template, "azuread") require.NoError(t, err) template = models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: false, + Settings: map[string]any{ + "enabled": true, }, } err = populateSSOSettings(sqlStore, template, "okta") @@ -288,8 +285,8 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { providers := []string{"azuread", "github", "google"} template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, + Settings: map[string]any{ + "enabled": true, }, } err := populateSSOSettings(sqlStore, template, providers...) @@ -310,8 +307,8 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { providers := []string{"github", "google", "okta"} invalidProvider := "azuread" template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, + Settings: map[string]any{ + "enabled": true, }, } err := populateSSOSettings(sqlStore, template, providers...) @@ -332,8 +329,8 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { providers := []string{"azuread", "github", "google"} template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, + Settings: map[string]any{ + "enabled": true, }, IsDeleted: true, } @@ -355,8 +352,8 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { provider := "azuread" template := models.SSOSettings{ - OAuthSettings: &social.OAuthInfo{ - Enabled: true, + Settings: map[string]any{ + "enabled": true, }, } // insert sso for the same provider 2 times in the database @@ -378,16 +375,15 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { func populateSSOSettings(sqlStore *sqlstore.SQLStore, template models.SSOSettings, providers ...string) error { return sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { for _, provider := range providers { - template.Provider = provider - template.ID = uuid.New().String() - template.Created = timeNow().UTC() - - dto, err := template.ToSSOSettingsDTO() - if err != nil { - return err + settings := models.SSOSettings{ + ID: uuid.New().String(), + Provider: provider, + Settings: template.Settings, + Created: timeNow().UTC(), + IsDeleted: template.IsDeleted, } - _, err = sess.Insert(dto) + _, err := sess.Insert(settings) if err != nil { return err } @@ -410,7 +406,7 @@ func getSSOSettingsCountByDeleted(sqlStore *sqlstore.SQLStore) (deleted, notDele } func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSettings, error) { - var model models.SSOSettingsDTO + var model models.SSOSettings var err error err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { @@ -422,12 +418,7 @@ func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, dele return nil, err } - settings, err := model.ToSSOSettings() - if err != nil { - return nil, err - } - - return settings, err + return &model, err } func mockTimeNow(timeSeed time.Time) { diff --git a/pkg/services/ssosettings/models/models.go b/pkg/services/ssosettings/models/models.go index 719ef4b01bf..0f699a4bca7 100644 --- a/pkg/services/ssosettings/models/models.go +++ b/pkg/services/ssosettings/models/models.go @@ -4,8 +4,6 @@ import ( "encoding/json" "fmt" "time" - - "github.com/grafana/grafana/pkg/login/social" ) type SettingsSource int @@ -27,16 +25,6 @@ func (s SettingsSource) MarshalJSON() ([]byte, error) { } type SSOSettings struct { - ID string - Provider string - OAuthSettings *social.OAuthInfo - Created time.Time - Updated time.Time - IsDeleted bool - Source SettingsSource -} - -type SSOSettingsDTO struct { ID string `xorm:"id pk" json:"id"` Provider string `xorm:"provider" json:"provider"` Settings map[string]any `xorm:"settings" json:"settings"` @@ -47,51 +35,8 @@ type SSOSettingsDTO struct { } // TableName returns the table name (needed for Xorm) -func (s SSOSettingsDTO) TableName() string { +func (s SSOSettings) TableName() string { return "sso_setting" } -func (s SSOSettingsDTO) ToSSOSettings() (*SSOSettings, error) { - settingsEncoded, err := json.Marshal(s.Settings) - if err != nil { - return nil, err - } - - var settings social.OAuthInfo - err = json.Unmarshal(settingsEncoded, &settings) - if err != nil { - return nil, err - } - - return &SSOSettings{ - ID: s.ID, - Provider: s.Provider, - OAuthSettings: &settings, - Created: s.Created, - Updated: s.Updated, - IsDeleted: s.IsDeleted, - }, nil -} - -func (s SSOSettings) ToSSOSettingsDTO() (*SSOSettingsDTO, error) { - settingsEncoded, err := json.Marshal(s.OAuthSettings) - if err != nil { - return nil, err - } - - var settings map[string]any - err = json.Unmarshal(settingsEncoded, &settings) - if err != nil { - return nil, err - } - - return &SSOSettingsDTO{ - ID: s.ID, - Provider: s.Provider, - Settings: settings, - Created: s.Created, - Updated: s.Updated, - IsDeleted: s.IsDeleted, - Source: s.Source, - }, nil -} +// TODO: check if we need custom marshalling/unmarshalling functions for converting the settings keys to camelCase diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go index 023d1f27939..5e73adf373a 100644 --- a/pkg/services/ssosettings/ssosettings.go +++ b/pkg/services/ssosettings/ssosettings.go @@ -46,7 +46,8 @@ type Reloadable interface { // using the config file and/or environment variables. Used mostly for backwards compatibility. type FallbackStrategy interface { IsMatch(provider string) bool - GetProviderConfig(ctx context.Context, provider string) (any, error) + // TODO: check if GetProviderConfig can return an error + GetProviderConfig(ctx context.Context, provider string) (map[string]any, error) } // Store is a SSO settings store diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go index c51f0b7b13f..8e974da5090 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -3,11 +3,12 @@ package ssosettingsimpl import ( "context" "errors" + "fmt" + "strings" "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/login/social" ac "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/secrets" @@ -95,12 +96,12 @@ func (s *SSOSettingsService) List(ctx context.Context) ([]*models.SSOSettings, e settings := getSettingsByProvider(provider, storedSettings) if len(settings) == 0 { // If there is no data in the DB then we need to load the settings using the fallback strategy - setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + fallbackSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) if err != nil { return nil, err } - settings = append(settings, setting) + settings = append(settings, fallbackSettings) } result = append(result, settings...) } @@ -109,23 +110,16 @@ func (s *SSOSettingsService) List(ctx context.Context) ([]*models.SSOSettings, e } func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error { + var err error // TODO: also check whether the provider is configurable // Get the connector for the provider (from the reloadables) and call Validate - if isOAuthProvider(settings.Provider) { - encryptedClientSecret, err := s.secrets.Encrypt(ctx, []byte(settings.OAuthSettings.ClientSecret), secrets.WithoutScope()) - if err != nil { - return err - } - settings.OAuthSettings.ClientSecret = string(encryptedClientSecret) - } - - err := s.store.Upsert(ctx, settings) + settings.Settings, err = s.encryptSecrets(ctx, settings.Settings) if err != nil { return err } - return nil + return s.store.Upsert(ctx, settings) } func (s *SSOSettingsService) Patch(ctx context.Context, provider string, data map[string]any) error { @@ -162,16 +156,11 @@ func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Conte return nil, err } - switch settingsFromSystem := settingsFromSystem.(type) { - case *social.OAuthInfo: - return &models.SSOSettings{ - Provider: provider, - Source: models.System, - OAuthSettings: settingsFromSystem, - }, nil - default: - return nil, errors.New("could not parse settings from system") - } + return &models.SSOSettings{ + Provider: provider, + Source: models.System, + Settings: settingsFromSystem, + }, nil } func getSettingsByProvider(provider string, settings []*models.SSOSettings) []*models.SSOSettings { @@ -193,12 +182,32 @@ func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosetting return nil, false } -func isOAuthProvider(provider string) bool { - for _, oAuthProvider := range ssosettings.AllOAuthProviders { - if oAuthProvider == provider { - return true +func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) { + secretFieldPatterns := []string{"secret"} + + isSecret := func(field string) bool { + for _, v := range secretFieldPatterns { + if strings.Contains(strings.ToLower(field), strings.ToLower(v)) { + return true + } + } + return false + } + + for k, v := range settings { + if isSecret(k) { + strValue, ok := v.(string) + if !ok { + return settings, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v) + } + + encryptedSecret, err := s.secrets.Encrypt(ctx, []byte(strValue), secrets.WithoutScope()) + if err != nil { + return settings, err + } + settings[k] = string(encryptedSecret) } } - return false + return settings, nil } diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go index d909bfb6658..74e384a9bf4 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service_test.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes" @@ -31,14 +30,14 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) { name: "should return successfully", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, + Provider: "github", + Settings: map[string]any{"enabled": true}, + Source: models.DB, } }, want: &models.SSOSettings{ - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, + Provider: "github", + Settings: map[string]any{"enabled": true}, }, wantErr: false, }, @@ -53,12 +52,12 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) { setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true - env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: true} + env.fallbackStrategy.ExpectedConfig = map[string]any{"enabled": true} }, want: &models.SSOSettings{ - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.System, + Provider: "github", + Settings: map[string]any{"enabled": true}, + Source: models.System, }, wantErr: false, }, @@ -115,54 +114,54 @@ func TestSSOSettingsService_List(t *testing.T) { setup: func(env testEnv) { env.store.ExpectedSSOSettings = []*models.SSOSettings{ { - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, + Provider: "github", + Settings: map[string]any{"enabled": true}, + Source: models.DB, }, { - Provider: "okta", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.DB, + Provider: "okta", + Settings: map[string]any{"enabled": false}, + Source: models.DB, }, } env.fallbackStrategy.ExpectedIsMatch = true - env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: false} + env.fallbackStrategy.ExpectedConfig = map[string]any{"enabled": false} }, want: []*models.SSOSettings{ { - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: true}, - Source: models.DB, + Provider: "github", + Settings: map[string]any{"enabled": true}, + Source: models.DB, }, { - Provider: "okta", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.DB, + Provider: "okta", + Settings: map[string]any{"enabled": false}, + Source: models.DB, }, { - Provider: "gitlab", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "gitlab", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "generic_oauth", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "generic_oauth", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "google", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "google", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "azuread", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "azuread", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "grafana_com", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "grafana_com", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, }, wantErr: false, @@ -178,43 +177,43 @@ func TestSSOSettingsService_List(t *testing.T) { setup: func(env testEnv) { env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = true - env.fallbackStrategy.ExpectedConfig = &social.OAuthInfo{Enabled: false} + env.fallbackStrategy.ExpectedConfig = map[string]any{"enabled": false} }, want: []*models.SSOSettings{ { - Provider: "github", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "github", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "okta", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "okta", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "gitlab", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "gitlab", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "generic_oauth", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "generic_oauth", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "google", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "google", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "azuread", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "azuread", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, { - Provider: "grafana_com", - OAuthSettings: &social.OAuthInfo{Enabled: false}, - Source: models.System, + Provider: "grafana_com", + Settings: map[string]any{"enabled": false}, + Source: models.System, }, }, wantErr: false, @@ -255,15 +254,15 @@ func TestSSOSettingsService_Upsert(t *testing.T) { settings := models.SSOSettings{ Provider: "azuread", - OAuthSettings: &social.OAuthInfo{ - ClientId: "client-id", - ClientSecret: "client-secret", - Enabled: true, + Settings: map[string]any{ + "client_id": "client-id", + "client_secret": "client-secret", + "enabled": true, }, IsDeleted: false, } - env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() + env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() err := env.service.Upsert(context.Background(), settings) require.NoError(t, err) @@ -274,15 +273,15 @@ func TestSSOSettingsService_Upsert(t *testing.T) { settings := models.SSOSettings{ Provider: "azuread", - OAuthSettings: &social.OAuthInfo{ - ClientId: "client-id", - ClientSecret: "client-secret", - Enabled: true, + Settings: map[string]any{ + "client_id": "client-id", + "client_secret": "client-secret", + "enabled": true, }, IsDeleted: false, } - env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return(nil, errors.New("encryption failed")).Once() + env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once() err := env.service.Upsert(context.Background(), settings) require.Error(t, err) @@ -293,15 +292,15 @@ func TestSSOSettingsService_Upsert(t *testing.T) { settings := models.SSOSettings{ Provider: "azuread", - OAuthSettings: &social.OAuthInfo{ - ClientId: "client-id", - ClientSecret: "client-secret", - Enabled: true, + Settings: map[string]any{ + "client_id": "client-id", + "client_secret": "client-secret", + "enabled": true, }, IsDeleted: false, } - env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() + env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.store.ExpectedError = errors.New("upsert failed") err := env.service.Upsert(context.Background(), settings) diff --git a/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go b/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go index 110fcbffc44..0a8ae3a9d39 100644 --- a/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go +++ b/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go @@ -4,7 +4,7 @@ import context "context" type FakeFallbackStrategy struct { ExpectedIsMatch bool - ExpectedConfig any + ExpectedConfig map[string]any ExpectedError error } @@ -17,6 +17,6 @@ func (f *FakeFallbackStrategy) IsMatch(provider string) bool { return f.ExpectedIsMatch } -func (f *FakeFallbackStrategy) GetProviderConfig(ctx context.Context, provider string) (any, error) { +func (f *FakeFallbackStrategy) GetProviderConfig(ctx context.Context, provider string) (map[string]any, error) { return f.ExpectedConfig, f.ExpectedError } diff --git a/pkg/services/ssosettings/ssosettingstests/service_mock.go b/pkg/services/ssosettings/ssosettingstests/service_mock.go index 31efe79018e..04765452a87 100644 --- a/pkg/services/ssosettings/ssosettingstests/service_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/service_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.27.1. DO NOT EDIT. +// Code generated by mockery v2.37.1. DO NOT EDIT. package ssosettingstests @@ -120,13 +120,12 @@ func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) return r0 } -type mockConstructorTestingTNewMockService interface { +// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockService(t interface { mock.TestingT Cleanup(func()) -} - -// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockService(t mockConstructorTestingTNewMockService) *MockService { +}) *MockService { mock := &MockService{} mock.Mock.Test(t) diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go index 17d3adbe12b..55214c18fe8 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.27.1. DO NOT EDIT. +// Code generated by mockery v2.37.1. DO NOT EDIT. package ssosettingstests @@ -108,13 +108,12 @@ func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) er return r0 } -type mockConstructorTestingTNewMockStore interface { +// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStore(t interface { mock.TestingT Cleanup(func()) -} - -// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore { +}) *MockStore { mock := &MockStore{} mock.Mock.Test(t) diff --git a/pkg/services/ssosettings/strategies/oauth_strategy.go b/pkg/services/ssosettings/strategies/oauth_strategy.go index 3f43b6625dc..4420902bec2 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy.go @@ -4,23 +4,13 @@ import ( "context" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/login/social/connectors" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" ) type OAuthStrategy struct { cfg *setting.Cfg - settingsByProvider map[string]*social.OAuthInfo -} - -var extraKeysByProvider = map[string][]string{ - social.AzureADProviderName: connectors.ExtraAzureADSettingKeys, - social.GenericOAuthProviderName: connectors.ExtraGenericOAuthSettingKeys, - social.GitHubProviderName: connectors.ExtraGithubSettingKeys, - social.GrafanaComProviderName: connectors.ExtraGrafanaComSettingKeys, - social.GrafanaNetProviderName: connectors.ExtraGrafanaComSettingKeys, + settingsByProvider map[string]map[string]any } var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil) @@ -28,7 +18,7 @@ var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil) func NewOAuthStrategy(cfg *setting.Cfg) *OAuthStrategy { oauthStrategy := &OAuthStrategy{ cfg: cfg, - settingsByProvider: make(map[string]*social.OAuthInfo), + settingsByProvider: make(map[string]map[string]any), } oauthStrategy.loadAllSettings() @@ -40,7 +30,7 @@ func (s *OAuthStrategy) IsMatch(provider string) bool { return ok } -func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (any, error) { +func (s *OAuthStrategy) GetProviderConfig(_ context.Context, provider string) (map[string]any, error) { return s.settingsByProvider[provider], nil } @@ -55,49 +45,46 @@ func (s *OAuthStrategy) loadAllSettings() { } } -func (s *OAuthStrategy) loadSettingsForProvider(provider string) *social.OAuthInfo { +func (s *OAuthStrategy) loadSettingsForProvider(provider string) map[string]any { section := s.cfg.SectionWithEnvOverrides("auth." + provider) - result := &social.OAuthInfo{ - AllowAssignGrafanaAdmin: section.Key("allow_assign_grafana_admin").MustBool(false), - AllowSignup: section.Key("allow_sign_up").MustBool(false), - AllowedDomains: util.SplitString(section.Key("allowed_domains").Value()), - AllowedGroups: util.SplitString(section.Key("allowed_groups").Value()), - ApiUrl: section.Key("api_url").Value(), - AuthStyle: section.Key("auth_style").Value(), - AuthUrl: section.Key("auth_url").Value(), - AutoLogin: section.Key("auto_login").MustBool(false), - ClientId: section.Key("client_id").Value(), - ClientSecret: section.Key("client_secret").Value(), - EmailAttributeName: section.Key("email_attribute_name").Value(), - EmailAttributePath: section.Key("email_attribute_path").Value(), - EmptyScopes: section.Key("empty_scopes").MustBool(false), - Enabled: section.Key("enabled").MustBool(false), - GroupsAttributePath: section.Key("groups_attribute_path").Value(), - HostedDomain: section.Key("hosted_domain").Value(), - Icon: section.Key("icon").Value(), - Name: section.Key("name").Value(), - RoleAttributePath: section.Key("role_attribute_path").Value(), - RoleAttributeStrict: section.Key("role_attribute_strict").MustBool(false), - Scopes: util.SplitString(section.Key("scopes").Value()), - SignoutRedirectUrl: section.Key("signout_redirect_url").Value(), - SkipOrgRoleSync: section.Key("skip_org_role_sync").MustBool(false), - TeamIdsAttributePath: section.Key("team_ids_attribute_path").Value(), - TeamsUrl: section.Key("teams_url").Value(), - TlsClientCa: section.Key("tls_client_ca").Value(), - TlsClientCert: section.Key("tls_client_cert").Value(), - TlsClientKey: section.Key("tls_client_key").Value(), - TlsSkipVerify: section.Key("tls_skip_verify_insecure").MustBool(false), - TokenUrl: section.Key("token_url").Value(), - UsePKCE: section.Key("use_pkce").MustBool(false), - UseRefreshToken: section.Key("use_refresh_token").MustBool(false), - Extra: map[string]string{}, + return map[string]any{ + "client_id": section.Key("client_id").Value(), + "client_secret": section.Key("client_secret").Value(), + "scopes": section.Key("scopes").Value(), + "empty_scopes": section.Key("empty_scopes").MustBool(false), + "auth_style": section.Key("auth_style").Value(), + "auth_url": section.Key("auth_url").Value(), + "token_url": section.Key("token_url").Value(), + "api_url": section.Key("api_url").Value(), + "teams_url": section.Key("teams_url").Value(), + "enabled": section.Key("enabled").MustBool(false), + "email_attribute_name": section.Key("email_attribute_name").Value(), + "email_attribute_path": section.Key("email_attribute_path").Value(), + "role_attribute_path": section.Key("role_attribute_path").Value(), + "role_attribute_strict": section.Key("role_attribute_strict").MustBool(false), + "groups_attribute_path": section.Key("groups_attribute_path").Value(), + "team_ids_attribute_path": section.Key("team_ids_attribute_path").Value(), + "allowed_domains": section.Key("allowed_domains").Value(), + "hosted_domain": section.Key("hosted_domain").Value(), + "allow_sign_up": section.Key("allow_sign_up").MustBool(false), + "name": section.Key("name").Value(), + "icon": section.Key("icon").Value(), + "skip_org_role_sync": section.Key("skip_org_role_sync").MustBool(false), + "tls_client_cert": section.Key("tls_client_cert").Value(), + "tls_client_key": section.Key("tls_client_key").Value(), + "tls_client_ca": section.Key("tls_client_ca").Value(), + "tls_skip_verify_insecure": section.Key("tls_skip_verify_insecure").MustBool(false), + "use_pkce": section.Key("use_pkce").MustBool(false), + "use_refresh_token": section.Key("use_refresh_token").MustBool(false), + "allow_assign_grafana_admin": section.Key("allow_assign_grafana_admin").MustBool(false), + "auto_login": section.Key("auto_login").MustBool(false), + "allowed_groups": section.Key("allowed_groups").Value(), + "signout_redirect_url": section.Key("signout_redirect_url").Value(), + "allowed_organizations": section.Key("allowed_organizations").Value(), + "id_token_attribute_name": section.Key("id_token_attribute_name").Value(), + "login_attribute_path": section.Key("login_attribute_path").Value(), + "name_attribute_path": section.Key("name_attribute_path").Value(), + "team_ids": section.Key("team_ids").Value(), } - - extraFields := extraKeysByProvider[provider] - for _, key := range extraFields { - result.Extra[key] = section.Key(key).Value() - } - - return result } diff --git a/pkg/services/ssosettings/strategies/oauth_strategy_test.go b/pkg/services/ssosettings/strategies/oauth_strategy_test.go index 414f7deb9a4..4250abf339b 100644 --- a/pkg/services/ssosettings/strategies/oauth_strategy_test.go +++ b/pkg/services/ssosettings/strategies/oauth_strategy_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/ini.v1" - "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/setting" ) @@ -54,46 +53,44 @@ var ( signout_redirect_url = test_signout_redirect_url ` - expectedOAuthInfo = &social.OAuthInfo{ - Name: "OAuth", - Icon: "signin", - Enabled: true, - AllowSignup: false, - AutoLogin: true, - ClientId: "test_client_id", - ClientSecret: "test_client_secret", - Scopes: []string{"openid", "profile", "email"}, - EmptyScopes: false, - EmailAttributeName: "email:primary", - EmailAttributePath: "email", - RoleAttributePath: "role", - RoleAttributeStrict: true, - GroupsAttributePath: "groups", - TeamIdsAttributePath: "team_ids", - AuthUrl: "test_auth_url", - TokenUrl: "test_token_url", - ApiUrl: "test_api_url", - TeamsUrl: "test_teams_url", - AllowedDomains: []string{"domain1.com"}, - AllowedGroups: []string{}, - TlsSkipVerify: true, - TlsClientCert: "", - TlsClientKey: "", - TlsClientCa: "", - UsePKCE: false, - AuthStyle: "inheader", - AllowAssignGrafanaAdmin: true, - UseRefreshToken: true, - HostedDomain: "test_hosted_domain", - SkipOrgRoleSync: true, - SignoutRedirectUrl: "test_signout_redirect_url", - Extra: map[string]string{ - "allowed_organizations": "org1, org2", - "id_token_attribute_name": "id_token", - "login_attribute_path": "login", - "name_attribute_path": "name", - "team_ids": "first, second", - }, + expectedOAuthInfo = map[string]any{ + "name": "OAuth", + "icon": "signin", + "enabled": true, + "allow_sign_up": false, + "auto_login": true, + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "scopes": "[\"openid\", \"profile\", \"email\"]", + "empty_scopes": false, + "email_attribute_name": "email:primary", + "email_attribute_path": "email", + "role_attribute_path": "role", + "role_attribute_strict": true, + "groups_attribute_path": "groups", + "team_ids_attribute_path": "team_ids", + "auth_url": "test_auth_url", + "token_url": "test_token_url", + "api_url": "test_api_url", + "teams_url": "test_teams_url", + "allowed_domains": "domain1.com", + "allowed_groups": "", + "tls_skip_verify_insecure": true, + "tls_client_cert": "", + "tls_client_key": "", + "tls_client_ca": "", + "use_pkce": false, + "auth_style": "inheader", + "allow_assign_grafana_admin": true, + "use_refresh_token": true, + "hosted_domain": "test_hosted_domain", + "skip_org_role_sync": true, + "signout_redirect_url": "test_signout_redirect_url", + "allowed_organizations": "org1, org2", + "id_token_attribute_name": "id_token", + "login_attribute_path": "login", + "name_attribute_path": "name", + "team_ids": "first, second", } ) @@ -106,10 +103,7 @@ func TestGetProviderConfig_EnvVarsOnly(t *testing.T) { result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth") require.NoError(t, err) - oauthInfo, ok := result.(*social.OAuthInfo) - require.True(t, ok) - - require.Equal(t, expectedOAuthInfo, oauthInfo) + require.Equal(t, expectedOAuthInfo, result) } func TestGetProviderConfig_IniFileOnly(t *testing.T) { @@ -124,10 +118,7 @@ func TestGetProviderConfig_IniFileOnly(t *testing.T) { result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth") require.NoError(t, err) - oauthInfo, ok := result.(*social.OAuthInfo) - require.True(t, ok) - - require.Equal(t, expectedOAuthInfo, oauthInfo) + require.Equal(t, expectedOAuthInfo, result) } func TestGetProviderConfig_EnvVarsOverrideIniFileSettings(t *testing.T) { @@ -145,14 +136,11 @@ func TestGetProviderConfig_EnvVarsOverrideIniFileSettings(t *testing.T) { result, err := strategy.GetProviderConfig(context.Background(), "generic_oauth") require.NoError(t, err) - oauthInfo, ok := result.(*social.OAuthInfo) - require.True(t, ok) + expectedOAuthInfoWithOverrides := expectedOAuthInfo + expectedOAuthInfoWithOverrides["enabled"] = false + expectedOAuthInfoWithOverrides["skip_org_role_sync"] = false - expectedOAuthInfoWithOverrides := *expectedOAuthInfo - expectedOAuthInfoWithOverrides.Enabled = false - expectedOAuthInfoWithOverrides.SkipOrgRoleSync = false - - require.Equal(t, expectedOAuthInfoWithOverrides, *oauthInfo) + require.Equal(t, expectedOAuthInfoWithOverrides, result) } func setupEnvVars(t *testing.T) {