Refactor SSOSettings to use types (#78675)

* refactor SSOSettings to use types

* test struct

* refactor SSOSettings struct to use types

* fix database tests

* fix populateSSOSettings() to accept an SSOSettings param

* fix all tests from the database layer

* handle errors for converting to/from SSOSettings

* add json tag on OAuthInfo fields

* use continue instead of if/else

* add the source field to SSOSettingsDTO conversion

* remove omitempty from json tags in OAuthInfo struct
This commit is contained in:
Mihai Doarna 2023-11-29 18:02:04 +02:00 committed by GitHub
parent 931c8e99b9
commit 2e2b1cd9e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 445 additions and 383 deletions

View File

@ -65,7 +65,7 @@ type keySetJWKS struct {
}
func NewAzureADProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (*SocialAzureAD, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -197,9 +197,9 @@ func convertIniSectionToMap(sec *ini.Section) map[string]any {
return mappedSettings
}
// createOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure
// CreateOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure
// it puts all extra key values into OAuthInfo's Extra map
func createOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) {
func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) {
emptyStrToSliceDecodeHook := func(from reflect.Type, to reflect.Type, data any) (any, error) {
if from.Kind() == reflect.String && to.Kind() == reflect.Slice {
strData, ok := data.(string)

View File

@ -33,7 +33,7 @@ token_url = test_token_url
api_url = test_api_url
teams_url = test_teams_url
allowed_domains = domain1.com
allowed_groups =
allowed_groups =
team_ids = first, second
allowed_organizations = org1, org2
tls_skip_verify_insecure = true
@ -96,7 +96,7 @@ signout_redirect_url = https://oauth.com/signout?post_logout_redirect_uri=https:
}
settingsKVs := convertIniSectionToMap(iniFile.Section("test"))
oauthInfo, err := createOAuthInfoFromKeyValues(settingsKVs)
oauthInfo, err := CreateOAuthInfoFromKeyValues(settingsKVs)
require.NoError(t, err)
require.Equal(t, expectedOAuthInfo, oauthInfo)

View File

@ -37,7 +37,7 @@ type SocialGenericOAuth struct {
}
func NewGenericOAuthProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGenericOAuth, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -50,7 +50,7 @@ var (
)
func NewGitHubProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGithub, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -50,7 +50,7 @@ type userData struct {
}
func NewGitLabProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGitlab, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -37,7 +37,7 @@ type googleUserData struct {
}
func NewGoogleProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGoogle, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -29,7 +29,7 @@ type OrgRecord struct {
}
func NewGrafanaComProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGrafanaCom, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -44,7 +44,7 @@ type OktaClaims struct {
}
func NewOktaProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialOkta, error) {
info, err := createOAuthInfoFromKeyValues(settings)
info, err := CreateOAuthInfoFromKeyValues(settings)
if err != nil {
return nil, err
}

View File

@ -44,38 +44,38 @@ type SocialService struct {
}
type OAuthInfo struct {
ApiUrl string `mapstructure:"api_url" toml:"api_url"`
AuthUrl string `mapstructure:"auth_url" toml:"auth_url"`
AuthStyle string `mapstructure:"auth_style" toml:"auth_style"`
ClientId string `mapstructure:"client_id" toml:"client_id"`
ClientSecret string `mapstructure:"client_secret" toml:"-"`
EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name"`
EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path"`
EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes"`
GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path"`
HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain"`
Icon string `mapstructure:"icon" toml:"icon"`
Name string `mapstructure:"name" toml:"name"`
RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path"`
TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path"`
TeamsUrl string `mapstructure:"teams_url" toml:"teams_url"`
TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca"`
TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert"`
TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key"`
TokenUrl string `mapstructure:"token_url" toml:"token_url"`
AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains"`
AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups"`
Scopes []string `mapstructure:"scopes" toml:"scopes"`
AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin"`
AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up"`
AutoLogin bool `mapstructure:"auto_login" toml:"auto_login"`
Enabled bool `mapstructure:"enabled" toml:"enabled"`
RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict"`
TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure"`
UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce"`
UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token"`
SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url"`
Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty"`
ApiUrl string `mapstructure:"api_url" toml:"api_url" json:"apiUrl"`
AuthUrl string `mapstructure:"auth_url" toml:"auth_url" json:"authUrl"`
AuthStyle string `mapstructure:"auth_style" toml:"auth_style" json:"authStyle"`
ClientId string `mapstructure:"client_id" toml:"client_id" json:"clientId"`
ClientSecret string `mapstructure:"client_secret" toml:"-" json:"clientSecret"`
EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name" json:"emailAttributeName"`
EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path" json:"emailAttributePath"`
EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes" json:"emptyScopes"`
GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path" json:"groupsAttributePath"`
HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain" json:"hostedDomain"`
Icon string `mapstructure:"icon" toml:"icon" json:"icon"`
Name string `mapstructure:"name" toml:"name" json:"name"`
RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path" json:"roleAttributePath"`
TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path" json:"teamIdsAttributePath"`
TeamsUrl string `mapstructure:"teams_url" toml:"teams_url" json:"teamsUrl"`
TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca" json:"tlsClientCa"`
TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert" json:"tlsClientCert"`
TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key" json:"tlsClientKey"`
TokenUrl string `mapstructure:"token_url" toml:"token_url" json:"tokenUrl"`
AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains" json:"allowedDomains"`
AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups" json:"allowedGroups"`
Scopes []string `mapstructure:"scopes" toml:"scopes" json:"scopes"`
AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin" json:"allowAssignGrafanaAdmin"`
AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up" json:"allowSignup"`
AutoLogin bool `mapstructure:"auto_login" toml:"auto_login" json:"autoLogin"`
Enabled bool `mapstructure:"enabled" toml:"enabled" json:"enabled"`
RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict" json:"roleAttributeStrict"`
TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure" json:"tlsSkipVerify"`
UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce" json:"usePKCE"`
UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token" json:"useRefreshToken"`
SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url" json:"signoutRedirectUrl"`
Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty" json:"extra"`
}
func ProvideService(cfg *setting.Cfg,
@ -97,7 +97,7 @@ func ProvideService(cfg *setting.Cfg,
sec := cfg.Raw.Section("auth." + name)
settingsKVs := convertIniSectionToMap(sec)
info, err := createOAuthInfoFromKeyValues(settingsKVs)
info, err := CreateOAuthInfoFromKeyValues(settingsKVs)
if err != nil {
ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", name)
continue

View File

@ -61,7 +61,18 @@ func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Re
return response.Error(500, "Failed to get providers", err)
}
return response.JSON(http.StatusOK, providers)
dtos := make([]*models.SSOSettingsDTO, 0)
for _, provider := range providers {
dto, err := provider.ToSSOSettingsDTO()
if err != nil {
api.Log.Warn("Failed to convert SSO Settings for provider " + provider.Provider)
continue
}
dtos = append(dtos, dto)
}
return response.JSON(http.StatusOK, dtos)
}
func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Response {
@ -75,7 +86,12 @@ func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Respons
return response.Error(http.StatusNotFound, "The provider was not found", err)
}
return response.JSON(http.StatusOK, settings)
dto, err := settings.ToSSOSettingsDTO()
if err != nil {
return response.Error(http.StatusInternalServerError, "The provider is invalid", err)
}
return response.JSON(http.StatusOK, dto)
}
func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Response {
@ -84,12 +100,19 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
return response.Error(http.StatusBadRequest, "Missing key", nil)
}
var newSettings models.SSOSetting
if err := web.Bind(c.Req, &newSettings); err != nil {
var settingsDTO models.SSOSettingsDTO
if err := web.Bind(c.Req, &settingsDTO); err != nil {
return response.Error(http.StatusBadRequest, "Failed to parse request body", err)
}
err := api.SSOSettingsService.Upsert(c.Req.Context(), key, newSettings.Settings)
settings, err := settingsDTO.ToSSOSettings()
if err != nil {
return response.Error(http.StatusBadRequest, "Invalid request body", err)
}
settings.Provider = key
err = api.SSOSettingsService.Upsert(c.Req.Context(), *settings)
// TODO: first check whether the error is referring to validation errors
// other error

View File

@ -31,8 +31,8 @@ func ProvideStore(sqlStore db.DB) *SSOSettingsStore {
var _ ssosettings.Store = (*SSOSettingsStore)(nil)
func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
result := models.SSOSetting{Provider: provider}
func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
result := models.SSOSettingsDTO{Provider: provider}
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
var err error
sess.Table("sso_setting")
@ -53,14 +53,19 @@ func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SS
return nil, err
}
return &result, nil
dto, err := result.ToSSOSettings()
if err != nil {
return nil, err
}
return dto, nil
}
func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, error) {
result := make([]*models.SSOSetting, 0)
func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
dtos := make([]*models.SSOSettingsDTO, 0)
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
sess.Table("sso_setting")
err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result)
err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&dtos)
if err != nil {
return err
@ -73,13 +78,29 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, erro
return nil, err
}
return result, nil
settings := make([]*models.SSOSettings, 0)
for _, dto := range dtos {
item, err := dto.ToSSOSettings()
if err != nil {
s.log.Warn("Failed to convert DB settings to SSOSettings for provider " + dto.Provider)
continue
}
settings = append(settings, item)
}
return settings, nil
}
func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
dto, err := settings.ToSSOSettingsDTO()
if err != nil {
return err
}
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
existing := &models.SSOSetting{
Provider: provider,
existing := &models.SSOSettingsDTO{
Provider: dto.Provider,
IsDeleted: false,
}
found, err := sess.UseBool("is_deleted").Exist(existing)
@ -90,17 +111,17 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map
now := timeNow().UTC()
if found {
updated := &models.SSOSetting{
Settings: data,
updated := &models.SSOSettingsDTO{
Settings: dto.Settings,
Updated: now,
IsDeleted: false,
}
_, err = sess.UseBool("is_deleted").Update(updated, existing)
} else {
_, err = sess.Insert(&models.SSOSetting{
_, err = sess.Insert(&models.SSOSettingsDTO{
ID: uuid.New().String(),
Provider: provider,
Settings: data,
Provider: dto.Provider,
Settings: dto.Settings,
Created: now,
Updated: now,
})
@ -116,7 +137,7 @@ func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[
func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error {
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
existing := &models.SSOSetting{
existing := &models.SSOSettingsDTO{
Provider: provider,
IsDeleted: false,
}

View File

@ -7,9 +7,9 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/ssosettings"
"github.com/grafana/grafana/pkg/services/ssosettings/models"
@ -27,24 +27,27 @@ func TestIntegrationGetSSOSettings(t *testing.T) {
sqlStore = db.InitTestDB(t)
ssoSettingsStore = ProvideStore(sqlStore)
err := insertSSOSetting(ssoSettingsStore, "azuread", nil)
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
}
err := populateSSOSettings(sqlStore, template, "azuread")
require.NoError(t, err)
}
t.Run("returns existing SSO settings", func(t *testing.T) {
setup()
expected := &models.SSOSetting{
Provider: "azuread",
Settings: map[string]interface{}{
"enabled": true,
},
expected := &models.SSOSettings{
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{Enabled: true},
}
actual, err := ssoSettingsStore.Get(context.Background(), "azuread")
require.NoError(t, err)
require.True(t, maps.Equal(expected.Settings, actual.Settings))
require.Equal(t, expected.OAuthSettings, actual.OAuthSettings)
})
t.Run("returns not found if the SSO setting is missing for the specified provider", func(t *testing.T) {
@ -83,18 +86,21 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
mockTimeNow(time.Now())
defer resetTimeNow()
provider := "azuread"
settings := map[string]interface{}{
"enabled": true,
"client_id": "azuread-client",
settings := models.SSOSettings{
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "azuread-client",
},
}
err := ssoSettingsStore.Upsert(context.Background(), provider, settings)
err := ssoSettingsStore.Upsert(context.Background(), settings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
require.NoError(t, err)
require.Equal(t, settings, actual.Settings)
require.EqualValues(t, settings.OAuthSettings, actual.OAuthSettings)
require.NotEmpty(t, actual.ID)
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created))
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
@ -111,25 +117,30 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
defer resetTimeNow()
provider := "github"
settings := map[string]interface{}{
"enabled": true,
"client_id": "github-client",
"client_secret": "this-is-a-secret",
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "github-client",
ClientSecret: "this-is-a-secret",
},
}
err := populateSSOSettings(sqlStore, settings, false, provider)
err := populateSSOSettings(sqlStore, template, provider)
require.NoError(t, err)
newSettings := map[string]interface{}{
"enabled": true,
"client_id": "new-github-client",
"client_secret": "this-is-a-new-secret",
newSettings := models.SSOSettings{
Provider: provider,
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "new-github-client",
ClientSecret: "this-is-a-new-secret",
},
}
err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings)
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
require.NoError(t, err)
require.Equal(t, newSettings, actual.Settings)
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
deleted, notDeleted, err := getSSOSettingsCountByDeleted(sqlStore)
@ -145,32 +156,38 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
defer resetTimeNow()
provider := "azuread"
settings := map[string]interface{}{
"enabled": true,
"client_id": "azuread-client",
"client_secret": "this-is-a-secret",
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "azuread-client",
ClientSecret: "this-is-a-secret",
},
IsDeleted: true,
}
err := populateSSOSettings(sqlStore, settings, true, provider)
err := populateSSOSettings(sqlStore, template, provider)
require.NoError(t, err)
newSettings := map[string]interface{}{
"enabled": true,
"client_id": "new-azuread-client",
"client_secret": "this-is-a-new-secret",
newSettings := models.SSOSettings{
Provider: provider,
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "new-azuread-client",
ClientSecret: "this-is-a-new-secret",
},
}
err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings)
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
require.NoError(t, err)
require.Equal(t, newSettings, actual.Settings)
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created))
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
old, err := getSSOSettingsByProvider(sqlStore, provider, true)
require.NoError(t, err)
require.Equal(t, settings, old.Settings)
require.Equal(t, template.OAuthSettings, old.OAuthSettings)
})
t.Run("replaces the settings only for the specified provider leaving the other provider's settings unchanged", func(t *testing.T) {
@ -180,31 +197,36 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
defer resetTimeNow()
providers := []string{"github", "gitlab", "google"}
settings := map[string]interface{}{
"enabled": true,
"client_id": "my-client",
"client_secret": "this-is-a-secret",
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "my-client",
ClientSecret: "this-is-a-secret",
},
}
err := populateSSOSettings(sqlStore, settings, false, providers...)
err := populateSSOSettings(sqlStore, template, providers...)
require.NoError(t, err)
newSettings := map[string]interface{}{
"enabled": true,
"client_id": "my-new-client",
"client_secret": "this-is-a-new-secret",
newSettings := models.SSOSettings{
Provider: providers[0],
OAuthSettings: &social.OAuthInfo{
Enabled: true,
ClientId: "my-new-client",
ClientSecret: "this-is-my-new-secret",
},
}
err = ssoSettingsStore.Upsert(context.Background(), providers[0], newSettings)
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
require.NoError(t, err)
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
require.NoError(t, err)
require.Equal(t, newSettings, actual.Settings)
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
for index := 1; index < len(providers); index++ {
existing, err := getSSOSettingsByProvider(sqlStore, providers[index], false)
require.NoError(t, err)
require.Equal(t, settings, existing.Settings)
require.EqualValues(t, template.OAuthSettings, existing.OAuthSettings)
}
})
}
@ -221,14 +243,20 @@ func TestIntegrationListSSOSettings(t *testing.T) {
sqlStore = db.InitTestDB(t)
ssoSettingsStore = ProvideStore(sqlStore)
err := insertSSOSetting(ssoSettingsStore, "azuread", map[string]interface{}{
"enabled": true,
})
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
}
err := populateSSOSettings(sqlStore, template, "azuread")
require.NoError(t, err)
err = insertSSOSetting(ssoSettingsStore, "okta", map[string]interface{}{
"enabled": false,
})
template = models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: false,
},
}
err = populateSSOSettings(sqlStore, template, "okta")
require.NoError(t, err)
}
@ -259,8 +287,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
setup()
providers := []string{"azuread", "github", "google"}
err := populateSSOSettings(sqlStore, nil, false, providers...)
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
}
err := populateSSOSettings(sqlStore, template, providers...)
require.NoError(t, err)
err = ssoSettingsStore.Delete(context.Background(), providers[0])
@ -277,8 +309,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
providers := []string{"github", "google", "okta"}
invalidProvider := "azuread"
err := populateSSOSettings(sqlStore, nil, false, providers...)
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
}
err := populateSSOSettings(sqlStore, template, providers...)
require.NoError(t, err)
err = ssoSettingsStore.Delete(context.Background(), invalidProvider)
@ -295,8 +331,13 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
setup()
providers := []string{"azuread", "github", "google"}
err := populateSSOSettings(sqlStore, nil, true, providers...)
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
IsDeleted: true,
}
err := populateSSOSettings(sqlStore, template, providers...)
require.NoError(t, err)
err = ssoSettingsStore.Delete(context.Background(), providers[0])
@ -313,11 +354,15 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
setup()
provider := "azuread"
template := models.SSOSettings{
OAuthSettings: &social.OAuthInfo{
Enabled: true,
},
}
// insert sso for the same provider 2 times in the database
err := populateSSOSettings(sqlStore, nil, false, provider)
err := populateSSOSettings(sqlStore, template, provider)
require.NoError(t, err)
err = populateSSOSettings(sqlStore, nil, false, provider)
err = populateSSOSettings(sqlStore, template, provider)
require.NoError(t, err)
err = ssoSettingsStore.Delete(context.Background(), provider)
@ -330,25 +375,19 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
})
}
func insertSSOSetting(ssoSettingsStore ssosettings.Store, provider string, settings map[string]interface{}) error {
if settings == nil {
settings = map[string]interface{}{
"enabled": true,
}
}
return ssoSettingsStore.Upsert(context.Background(), provider, settings)
}
func populateSSOSettings(sqlStore *sqlstore.SQLStore, settings map[string]interface{}, deleted bool, providers ...string) error {
func populateSSOSettings(sqlStore *sqlstore.SQLStore, template models.SSOSettings, providers ...string) error {
return sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
for _, provider := range providers {
_, err := sess.Insert(&models.SSOSetting{
ID: uuid.New().String(),
Provider: provider,
Settings: settings,
Created: timeNow().UTC(),
IsDeleted: deleted,
})
template.Provider = provider
template.ID = uuid.New().String()
template.Created = timeNow().UTC()
dto, err := template.ToSSOSettingsDTO()
if err != nil {
return err
}
_, err = sess.Insert(dto)
if err != nil {
return err
}
@ -370,8 +409,8 @@ func getSSOSettingsCountByDeleted(sqlStore *sqlstore.SQLStore) (deleted, notDele
return
}
func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSetting, error) {
var model models.SSOSetting
func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSettings, error) {
var model models.SSOSettingsDTO
var err error
err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
@ -379,7 +418,16 @@ func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, dele
return err
})
return &model, err
if err != nil {
return nil, err
}
settings, err := model.ToSSOSettings()
if err != nil {
return nil, err
}
return settings, err
}
func mockTimeNow(timeSeed time.Time) {

View File

@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/grafana/grafana/pkg/services/featuremgmt/strcase"
"github.com/grafana/grafana/pkg/login/social"
)
type SettingsSource int
@ -26,8 +26,18 @@ func (s SettingsSource) MarshalJSON() ([]byte, error) {
}
}
type SSOSetting struct {
ID string `xorm:"id pk" json:"-"`
type SSOSettings struct {
ID string
Provider string
OAuthSettings *social.OAuthInfo
Created time.Time
Updated time.Time
IsDeleted bool
Source SettingsSource
}
type SSOSettingsDTO struct {
ID string `xorm:"id pk" json:"id"`
Provider string `xorm:"provider" json:"provider"`
Settings map[string]interface{} `xorm:"settings" json:"settings"`
Created time.Time `xorm:"created" json:"-"`
@ -37,51 +47,51 @@ type SSOSetting struct {
}
// TableName returns the table name (needed for Xorm)
func (s SSOSetting) TableName() string {
func (s SSOSettingsDTO) TableName() string {
return "sso_setting"
}
// MarshalJSON implements the json.Marshaler interface and converts the s.Settings from map[string]any to map[string]any in camelCase
func (s SSOSetting) MarshalJSON() ([]byte, error) {
type Alias SSOSetting
aux := &struct {
*Alias
}{
Alias: (*Alias)(&s),
func (s SSOSettingsDTO) ToSSOSettings() (*SSOSettings, error) {
settingsEncoded, err := json.Marshal(s.Settings)
if err != nil {
return nil, err
}
settings := make(map[string]any)
for k, v := range aux.Settings {
settings[strcase.ToLowerCamel(k)] = v
var settings social.OAuthInfo
err = json.Unmarshal(settingsEncoded, &settings)
if err != nil {
return nil, err
}
aux.Settings = settings
return json.Marshal(aux)
return &SSOSettings{
ID: s.ID,
Provider: s.Provider,
OAuthSettings: &settings,
Created: s.Created,
Updated: s.Updated,
IsDeleted: s.IsDeleted,
}, nil
}
// UnmarshalJSON implements the json.Unmarshaler interface and converts the settings from map[string]any camelCase to map[string]interface{} snake_case
func (s *SSOSetting) UnmarshalJSON(data []byte) error {
type Alias SSOSetting
aux := &struct {
*Alias
}{
Alias: (*Alias)(s),
func (s SSOSettings) ToSSOSettingsDTO() (*SSOSettingsDTO, error) {
settingsEncoded, err := json.Marshal(s.OAuthSettings)
if err != nil {
return nil, err
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
var settings map[string]interface{}
err = json.Unmarshal(settingsEncoded, &settings)
if err != nil {
return nil, err
}
settings := make(map[string]any)
for k, v := range aux.Settings {
settings[strcase.ToSnake(k)] = v
}
s.Settings = settings
return nil
}
type SSOSettingsResponse struct {
Settings map[string]interface{} `json:"settings"`
Provider string `json:"type"`
return &SSOSettingsDTO{
ID: s.ID,
Provider: s.Provider,
Settings: settings,
Created: s.Created,
Updated: s.Updated,
IsDeleted: s.IsDeleted,
Source: s.Source,
}, nil
}

View File

@ -20,11 +20,11 @@ var (
//go:generate mockery --name Service --structname MockService --outpkg ssosettingstests --filename service_mock.go --output ./ssosettingstests/
type Service interface {
// List returns all SSO settings from DB and config files
List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error)
List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error)
// GetForProvider returns the SSO settings for a given provider (DB or config file)
GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error)
GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error)
// Upsert creates or updates the SSO settings for a given provider
Upsert(ctx context.Context, provider string, data map[string]interface{}) error
Upsert(ctx context.Context, settings models.SSOSettings) error
// Delete deletes the SSO settings for a given provider (soft delete)
Delete(ctx context.Context, provider string) error
// Patch updates the specified SSO settings (key-value pairs) for a given provider
@ -52,9 +52,9 @@ type FallbackStrategy interface {
//
//go:generate mockery --name Store --structname MockStore --outpkg ssosettingstests --filename store_mock.go --output ./ssosettingstests/
type Store interface {
Get(ctx context.Context, provider string) (*models.SSOSetting, error)
List(ctx context.Context) ([]*models.SSOSetting, error)
Upsert(ctx context.Context, provider string, data map[string]interface{}) error
Get(ctx context.Context, provider string) (*models.SSOSettings, error)
List(ctx context.Context) ([]*models.SSOSettings, error)
Upsert(ctx context.Context, settings models.SSOSettings) error
Patch(ctx context.Context, provider string, data map[string]interface{}) error
Delete(ctx context.Context, provider string) error
}

View File

@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/featuremgmt"
@ -55,29 +56,29 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
var _ ssosettings.Service = (*SSOSettingsService)(nil)
func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) {
dto, err := s.store.Get(ctx, provider)
func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) {
storeSettings, err := s.store.Get(ctx, provider)
if errors.Is(err, ssosettings.ErrNotFound) {
setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
settings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
if err != nil {
return nil, err
}
return setting, nil
return settings, nil
}
if err != nil {
return nil, err
}
dto.Source = models.DB
storeSettings.Source = models.DB
return dto, nil
return storeSettings, nil
}
func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) {
result := make([]*models.SSOSetting, 0, len(ssosettings.AllOAuthProviders))
func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) {
result := make([]*models.SSOSettings, 0, len(ssosettings.AllOAuthProviders))
storedSettings, err := s.store.List(ctx)
if err != nil {
@ -98,12 +99,12 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
settings := getSettingsByProvider(provider, storedSettings)
if len(settings) == 0 {
// If there is no data in the DB then we need to load the settings using the fallback strategy
setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
fallbackSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
if err != nil {
return nil, err
}
settings = append(settings, setting)
settings = append(settings, fallbackSettings)
}
result = append(result, settings...)
}
@ -111,9 +112,9 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
return result, nil
}
func (s *SSOSettingsService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
// TODO: validation (configurable provider? Contains the required fields? etc)
err := s.store.Upsert(ctx, provider, data)
err := s.store.Upsert(ctx, settings)
if err != nil {
return err
}
@ -140,7 +141,7 @@ func (s *SSOSettingsService) RegisterFallbackStrategy(providerRegex string, stra
s.fbStrategies = append(s.fbStrategies, strategy)
}
func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSetting, error) {
func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSettings, error) {
loadStrategy, ok := s.getFallBackstrategyFor(provider)
if !ok {
return nil, errors.New("no fallback strategy found for provider: " + provider)
@ -151,18 +152,23 @@ func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Conte
return nil, err
}
return &models.SSOSetting{
Provider: provider,
Source: models.System,
Settings: settingsFromSystem,
oAuthInfo, err := social.CreateOAuthInfoFromKeyValues(settingsFromSystem)
if err != nil {
return nil, err
}
return &models.SSOSettings{
Provider: provider,
Source: models.System,
OAuthSettings: oAuthInfo,
}, nil
}
func getSettingsByProvider(provider string, settings []*models.SSOSetting) []*models.SSOSetting {
result := make([]*models.SSOSetting, 0)
for _, setting := range settings {
if setting.Provider == provider {
result = append(result, setting)
func getSettingsByProvider(provider string, settings []*models.SSOSettings) []*models.SSOSettings {
result := make([]*models.SSOSettings, 0)
for _, item := range settings {
if item.Provider == provider {
result = append(result, item)
}
}
return result

View File

@ -7,6 +7,7 @@ import (
"testing"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
"github.com/grafana/grafana/pkg/services/auth/identity"
@ -22,25 +23,21 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) {
testCases := []struct {
name string
setup func(env testEnv)
want *models.SSOSetting
want *models.SSOSettings
wantErr bool
}{
{
name: "should return successfully",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSetting{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.DB,
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
}
},
want: &models.SSOSetting{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
want: &models.SSOSettings{
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
},
wantErr: false,
},
@ -59,12 +56,10 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) {
"enabled": true,
}
},
want: &models.SSOSetting{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.System,
want: &models.SSOSettings{
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.System,
},
wantErr: false,
},
@ -136,26 +131,22 @@ func TestSSOSettingsService_List(t *testing.T) {
name string
setup func(env testEnv)
identity identity.Requester
want []*models.SSOSetting
want []*models.SSOSettings
wantErr bool
}{
{
name: "should return successfully",
setup: func(env testEnv) {
env.store.ExpectedSSOSettings = []*models.SSOSetting{
env.store.ExpectedSSOSettings = []*models.SSOSettings{
{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.DB,
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
},
{
Provider: "okta",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.DB,
Provider: "okta",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.DB,
},
}
env.fallbackStrategy.ExpectedIsMatch = true
@ -164,55 +155,41 @@ func TestSSOSettingsService_List(t *testing.T) {
}
},
identity: defaultIdentity,
want: []*models.SSOSetting{
want: []*models.SSOSettings{
{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.DB,
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
},
{
Provider: "okta",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.DB,
Provider: "okta",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.DB,
},
{
Provider: "gitlab",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "gitlab",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "generic_oauth",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "generic_oauth",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "google",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "google",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "azuread",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "grafana_com",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "grafana_com",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
},
wantErr: false,
@ -220,20 +197,16 @@ func TestSSOSettingsService_List(t *testing.T) {
{
name: "should return the settings that the user has access to",
setup: func(env testEnv) {
env.store.ExpectedSSOSettings = []*models.SSOSetting{
env.store.ExpectedSSOSettings = []*models.SSOSettings{
{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.DB,
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
},
{
Provider: "okta",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.DB,
Provider: "okta",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
},
}
env.fallbackStrategy.ExpectedIsMatch = true
@ -242,20 +215,16 @@ func TestSSOSettingsService_List(t *testing.T) {
}
},
identity: scopedIdentity,
want: []*models.SSOSetting{
want: []*models.SSOSettings{
{
Provider: "github",
Settings: map[string]interface{}{
"enabled": true,
},
Source: models.DB,
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: true},
Source: models.DB,
},
{
Provider: "azuread",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
},
wantErr: false,
@ -270,62 +239,48 @@ func TestSSOSettingsService_List(t *testing.T) {
{
name: "should use the fallback strategy if store returns empty list",
setup: func(env testEnv) {
env.store.ExpectedSSOSettings = []*models.SSOSetting{}
env.store.ExpectedSSOSettings = []*models.SSOSettings{}
env.fallbackStrategy.ExpectedIsMatch = true
env.fallbackStrategy.ExpectedConfig = map[string]interface{}{
"enabled": false,
}
},
identity: defaultIdentity,
want: []*models.SSOSetting{
want: []*models.SSOSettings{
{
Provider: "github",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "github",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "okta",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "okta",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "gitlab",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "gitlab",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "generic_oauth",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "generic_oauth",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "google",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "google",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "azuread",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
{
Provider: "grafana_com",
Settings: map[string]interface{}{
"enabled": false,
},
Source: models.System,
Provider: "grafana_com",
OAuthSettings: &social.OAuthInfo{Enabled: false},
Source: models.System,
},
},
wantErr: false,
@ -333,7 +288,7 @@ func TestSSOSettingsService_List(t *testing.T) {
{
name: "should return error if any of the fallback strategies was not found",
setup: func(env testEnv) {
env.store.ExpectedSSOSettings = []*models.SSOSetting{}
env.store.ExpectedSSOSettings = []*models.SSOSettings{}
env.fallbackStrategy.ExpectedIsMatch = false
},
identity: defaultIdentity,

View File

@ -33,19 +33,19 @@ func (_m *MockService) Delete(ctx context.Context, provider string) error {
}
// GetForProvider provides a mock function with given fields: ctx, provider
func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) {
func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) {
ret := _m.Called(ctx, provider)
var r0 *models.SSOSetting
var r0 *models.SSOSettings
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok {
return rf(ctx, provider)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok {
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok {
r0 = rf(ctx, provider)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SSOSetting)
r0 = ret.Get(0).(*models.SSOSettings)
}
}
@ -59,19 +59,19 @@ func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*mo
}
// List provides a mock function with given fields: ctx, requester
func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) {
func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) {
ret := _m.Called(ctx, requester)
var r0 []*models.SSOSetting
var r0 []*models.SSOSettings
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSetting, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSettings, error)); ok {
return rf(ctx, requester)
}
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSetting); ok {
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSettings); ok {
r0 = rf(ctx, requester)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SSOSetting)
r0 = ret.Get(0).([]*models.SSOSettings)
}
}
@ -108,13 +108,13 @@ func (_m *MockService) Reload(ctx context.Context, provider string) {
_m.Called(ctx, provider)
}
// Upsert provides a mock function with given fields: ctx, provider, data
func (_m *MockService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
ret := _m.Called(ctx, provider, data)
// Upsert provides a mock function with given fields: ctx, settings
func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error {
ret := _m.Called(ctx, settings)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok {
r0 = rf(ctx, provider, data)
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
r0 = rf(ctx, settings)
} else {
r0 = ret.Error(0)
}

View File

@ -10,8 +10,8 @@ import (
var _ ssosettings.Store = (*FakeStore)(nil)
type FakeStore struct {
ExpectedSSOSetting *models.SSOSetting
ExpectedSSOSettings []*models.SSOSetting
ExpectedSSOSetting *models.SSOSettings
ExpectedSSOSettings []*models.SSOSettings
ExpectedError error
}
@ -19,15 +19,15 @@ func NewFakeStore() *FakeStore {
return &FakeStore{}
}
func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
return f.ExpectedSSOSetting, f.ExpectedError
}
func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSetting, error) {
func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
return f.ExpectedSSOSettings, f.ExpectedError
}
func (f *FakeStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
return f.ExpectedError
}

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.27.1. DO NOT EDIT.
// Code generated by mockery v2.37.1. DO NOT EDIT.
package ssosettingstests
@ -29,19 +29,19 @@ func (_m *MockStore) Delete(ctx context.Context, provider string) error {
}
// Get provides a mock function with given fields: ctx, provider
func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
ret := _m.Called(ctx, provider)
var r0 *models.SSOSetting
var r0 *models.SSOSettings
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok {
return rf(ctx, provider)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok {
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok {
r0 = rf(ctx, provider)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SSOSetting)
r0 = ret.Get(0).(*models.SSOSettings)
}
}
@ -54,20 +54,20 @@ func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetti
return r0, r1
}
// GetAll provides a mock function with given fields: ctx
func (_m *MockStore) GetAll(ctx context.Context) ([]*models.SSOSetting, error) {
// List provides a mock function with given fields: ctx
func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
ret := _m.Called(ctx)
var r0 []*models.SSOSetting
var r0 []*models.SSOSettings
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSetting, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSettings, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSetting); ok {
if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSettings); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SSOSetting)
r0 = ret.Get(0).([]*models.SSOSettings)
}
}
@ -94,13 +94,13 @@ func (_m *MockStore) Patch(ctx context.Context, provider string, data map[string
return r0
}
// Upsert provides a mock function with given fields: ctx, provider, data
func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
ret := _m.Called(ctx, provider, data)
// Upsert provides a mock function with given fields: ctx, settings
func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
ret := _m.Called(ctx, settings)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok {
r0 = rf(ctx, provider, data)
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
r0 = rf(ctx, settings)
} else {
r0 = ret.Error(0)
}
@ -108,13 +108,12 @@ func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[strin
return r0
}
type mockConstructorTestingTNewMockStore interface {
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockStore(t interface {
mock.TestingT
Cleanup(func())
}
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore {
}) *MockStore {
mock := &MockStore{}
mock.Mock.Test(t)