diff --git a/pkg/services/ssosettings/database/database.go b/pkg/services/ssosettings/database/database.go index 8d971e8a423..ea3782a8bbf 100644 --- a/pkg/services/ssosettings/database/database.go +++ b/pkg/services/ssosettings/database/database.go @@ -12,6 +12,11 @@ import ( "github.com/grafana/grafana/pkg/services/ssosettings/models" ) +const ( + isDeletedColumn = "is_deleted" + updatedColumn = "updated" +) + type SSOSettingsStore struct { sqlStore db.DB log log.Logger @@ -27,12 +32,17 @@ func ProvideStore(sqlStore db.DB) *SSOSettingsStore { var _ ssosettings.Store = (*SSOSettingsStore)(nil) func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) { - result := models.SSOSettings{Provider: provider} - err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - var err error - sess.Table("sso_setting") - found, err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Get(&result) + if provider == "" { + return nil, ssosettings.ErrNotFound + } + result := models.SSOSettings{ + Provider: provider, + IsDeleted: false, + } + + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + found, err := sess.UseBool(isDeletedColumn).Get(&result) if err != nil { return err } @@ -53,10 +63,13 @@ func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SS func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, error) { result := make([]*models.SSOSettings, 0) - err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - sess.Table("sso_setting") - err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result) + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + condition := &models.SSOSettings{ + IsDeleted: false, + } + + err := sess.UseBool(isDeletedColumn).Find(&result, condition) if err != nil { return err } @@ -72,12 +85,17 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, err } func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error { + if settings.Provider == "" { + return ssosettings.ErrNotFound + } + return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { existing := &models.SSOSettings{ Provider: settings.Provider, IsDeleted: false, } - found, err := sess.UseBool("is_deleted").Exist(existing) + + found, err := sess.UseBool(isDeletedColumn).Exist(existing) if err != nil { return err } @@ -90,7 +108,7 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin Updated: now, IsDeleted: false, } - _, err = sess.UseBool("is_deleted").Update(updated, existing) + _, err = sess.UseBool(isDeletedColumn).Update(updated, existing) } else { _, err = sess.Insert(&models.SSOSettings{ ID: uuid.New().String(), @@ -105,18 +123,18 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettin }) } -func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[string]interface{}) error { - panic("not implemented") // TODO: Implement -} - func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error { + if provider == "" { + return ssosettings.ErrNotFound + } + return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { existing := &models.SSOSettings{ Provider: provider, IsDeleted: false, } - found, err := sess.UseBool("is_deleted").Get(existing) + found, err := sess.UseBool(isDeletedColumn).Get(existing) if err != nil { return err } @@ -128,7 +146,7 @@ func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error { existing.Updated = time.Now().UTC() existing.IsDeleted = true - _, err = sess.ID(existing.ID).MustCols("updated", "is_deleted").Update(existing) + _, err = sess.ID(existing.ID).MustCols(updatedColumn, isDeletedColumn).Update(existing) return err }) } diff --git a/pkg/services/ssosettings/database/database_test.go b/pkg/services/ssosettings/database/database_test.go index 5e8f1c6efe1..bb51ec46bbc 100644 --- a/pkg/services/ssosettings/database/database_test.go +++ b/pkg/services/ssosettings/database/database_test.go @@ -33,7 +33,7 @@ func TestIntegrationGetSSOSettings(t *testing.T) { template := models.SSOSettings{ Settings: map[string]any{"enabled": true}, } - err := populateSSOSettings(sqlStore, template, "azuread") + err := populateSSOSettings(sqlStore, template, "azuread", "github", "google") require.NoError(t, err) } @@ -60,10 +60,23 @@ func TestIntegrationGetSSOSettings(t *testing.T) { t.Run("returns not found if the SSO setting is soft deleted for the specified provider", func(t *testing.T) { setup() - err := ssoSettingsStore.Delete(context.Background(), "azuread") + + provider := "okta" + template := models.SSOSettings{ + Settings: map[string]any{"enabled": true}, + IsDeleted: true, + } + err := populateSSOSettings(sqlStore, template, provider) require.NoError(t, err) - _, err = ssoSettingsStore.Get(context.Background(), "azuread") + _, err = ssoSettingsStore.Get(context.Background(), provider) + require.ErrorAs(t, err, &ssosettings.ErrNotFound) + }) + + t.Run("returns not found if the specified provider is empty", func(t *testing.T) { + setup() + + _, err := ssoSettingsStore.Get(context.Background(), "") require.ErrorAs(t, err, &ssosettings.ErrNotFound) }) } @@ -218,6 +231,33 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { require.EqualValues(t, template.Settings, existing.Settings) } }) + + t.Run("fails if the provider is empty", func(t *testing.T) { + setup() + + template := models.SSOSettings{ + Settings: map[string]any{ + "enabled": true, + "client_id": "azuread-client", + "client_secret": "this-is-a-secret", + }, + IsDeleted: true, + } + err := populateSSOSettings(sqlStore, template, "azuread") + require.NoError(t, err) + + settings := models.SSOSettings{ + Provider: "", + Settings: map[string]any{ + "enabled": true, + "client_id": "new-client", + }, + } + + err = ssoSettingsStore.Upsert(context.Background(), settings) + require.Error(t, err) + require.ErrorIs(t, err, ssosettings.ErrNotFound) + }) } func TestIntegrationListSSOSettings(t *testing.T) { @@ -231,31 +271,58 @@ func TestIntegrationListSSOSettings(t *testing.T) { setup := func() { sqlStore = db.InitTestDB(t) ssoSettingsStore = ProvideStore(sqlStore) - - template := models.SSOSettings{ - Settings: map[string]any{ - "enabled": true, - }, - } - err := populateSSOSettings(sqlStore, template, "azuread") - require.NoError(t, err) - - template = models.SSOSettings{ - Settings: map[string]any{ - "enabled": true, - }, - } - err = populateSSOSettings(sqlStore, template, "okta") - require.NoError(t, err) } t.Run("returns every SSO settings successfully", func(t *testing.T) { setup() + providers := []string{"azuread", "okta", "github"} + settings := models.SSOSettings{ + Settings: map[string]any{ + "enabled": true, + "client_id": "the_client_id", + }, + IsDeleted: false, + } + err := populateSSOSettings(sqlStore, settings, providers...) + require.NoError(t, err) + + deleted := models.SSOSettings{ + Settings: map[string]any{ + "enabled": false, + }, + IsDeleted: true, + } + err = populateSSOSettings(sqlStore, deleted, "google", "gitlab", "okta") + require.NoError(t, err) + list, err := ssoSettingsStore.List(context.Background()) require.NoError(t, err) - require.Equal(t, 2, len(list)) + require.Len(t, list, len(providers)) + + for _, item := range list { + require.Contains(t, providers, item.Provider) + require.EqualValues(t, settings.Settings, item.Settings) + } + }) + + t.Run("returns empty list if no settings are found", func(t *testing.T) { + setup() + + deleted := models.SSOSettings{ + Settings: map[string]any{ + "enabled": false, + }, + IsDeleted: true, + } + err := populateSSOSettings(sqlStore, deleted, "google", "gitlab", "okta") + require.NoError(t, err) + + list, err := ssoSettingsStore.List(context.Background()) + + require.NoError(t, err) + require.Len(t, list, 0) }) } @@ -362,6 +429,28 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { require.EqualValues(t, 1, deleted) require.EqualValues(t, 1, notDeleted) }) + + t.Run("return not found if the provider is empty", func(t *testing.T) { + setup() + + providers := []string{"github", "google", "okta"} + template := models.SSOSettings{ + Settings: map[string]any{ + "enabled": true, + }, + } + err := populateSSOSettings(sqlStore, template, providers...) + require.NoError(t, err) + + err = ssoSettingsStore.Delete(context.Background(), "") + require.Error(t, err) + require.ErrorIs(t, err, ssosettings.ErrNotFound) + + deleted, notDeleted, err := getSSOSettingsCountByDeleted(sqlStore) + require.NoError(t, err) + require.EqualValues(t, 0, deleted) + require.EqualValues(t, len(providers), notDeleted) + }) } func populateSSOSettings(sqlStore *sqlstore.SQLStore, template models.SSOSettings, providers ...string) error { diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go index 59aa8a44a1b..ff1ffc99c95 100644 --- a/pkg/services/ssosettings/ssosettings.go +++ b/pkg/services/ssosettings/ssosettings.go @@ -61,6 +61,5 @@ type Store interface { Get(ctx context.Context, provider string) (*models.SSOSettings, error) List(ctx context.Context) ([]*models.SSOSettings, error) Upsert(ctx context.Context, settings models.SSOSettings) error - Patch(ctx context.Context, provider string, data map[string]any) error Delete(ctx context.Context, provider string) error } diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go index 9b04828519b..0db6238da07 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.39.2. DO NOT EDIT. package ssosettingstests @@ -92,24 +92,6 @@ func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) { return r0, r1 } -// Patch provides a mock function with given fields: ctx, provider, data -func (_m *MockStore) Patch(ctx context.Context, provider string, data map[string]interface{}) error { - ret := _m.Called(ctx, provider, data) - - if len(ret) == 0 { - panic("no return value specified for Patch") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok { - r0 = rf(ctx, provider, data) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // Upsert provides a mock function with given fields: ctx, settings func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error { ret := _m.Called(ctx, settings)