mirror of
https://github.com/grafana/grafana.git
synced 2025-01-27 16:57:14 -06:00
Refactor SSOSettings to use types (#78675)
* refactor SSOSettings to use types * test struct * refactor SSOSettings struct to use types * fix database tests * fix populateSSOSettings() to accept an SSOSettings param * fix all tests from the database layer * handle errors for converting to/from SSOSettings * add json tag on OAuthInfo fields * use continue instead of if/else * add the source field to SSOSettingsDTO conversion * remove omitempty from json tags in OAuthInfo struct
This commit is contained in:
parent
931c8e99b9
commit
2e2b1cd9e4
@ -65,7 +65,7 @@ type keySetJWKS struct {
|
||||
}
|
||||
|
||||
func NewAzureADProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (*SocialAzureAD, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -197,9 +197,9 @@ func convertIniSectionToMap(sec *ini.Section) map[string]any {
|
||||
return mappedSettings
|
||||
}
|
||||
|
||||
// createOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure
|
||||
// CreateOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure
|
||||
// it puts all extra key values into OAuthInfo's Extra map
|
||||
func createOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) {
|
||||
func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) {
|
||||
emptyStrToSliceDecodeHook := func(from reflect.Type, to reflect.Type, data any) (any, error) {
|
||||
if from.Kind() == reflect.String && to.Kind() == reflect.Slice {
|
||||
strData, ok := data.(string)
|
||||
|
@ -33,7 +33,7 @@ token_url = test_token_url
|
||||
api_url = test_api_url
|
||||
teams_url = test_teams_url
|
||||
allowed_domains = domain1.com
|
||||
allowed_groups =
|
||||
allowed_groups =
|
||||
team_ids = first, second
|
||||
allowed_organizations = org1, org2
|
||||
tls_skip_verify_insecure = true
|
||||
@ -96,7 +96,7 @@ signout_redirect_url = https://oauth.com/signout?post_logout_redirect_uri=https:
|
||||
}
|
||||
|
||||
settingsKVs := convertIniSectionToMap(iniFile.Section("test"))
|
||||
oauthInfo, err := createOAuthInfoFromKeyValues(settingsKVs)
|
||||
oauthInfo, err := CreateOAuthInfoFromKeyValues(settingsKVs)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, expectedOAuthInfo, oauthInfo)
|
||||
|
@ -37,7 +37,7 @@ type SocialGenericOAuth struct {
|
||||
}
|
||||
|
||||
func NewGenericOAuthProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGenericOAuth, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ var (
|
||||
)
|
||||
|
||||
func NewGitHubProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGithub, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ type userData struct {
|
||||
}
|
||||
|
||||
func NewGitLabProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGitlab, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ type googleUserData struct {
|
||||
}
|
||||
|
||||
func NewGoogleProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGoogle, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ type OrgRecord struct {
|
||||
}
|
||||
|
||||
func NewGrafanaComProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGrafanaCom, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ type OktaClaims struct {
|
||||
}
|
||||
|
||||
func NewOktaProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialOkta, error) {
|
||||
info, err := createOAuthInfoFromKeyValues(settings)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -44,38 +44,38 @@ type SocialService struct {
|
||||
}
|
||||
|
||||
type OAuthInfo struct {
|
||||
ApiUrl string `mapstructure:"api_url" toml:"api_url"`
|
||||
AuthUrl string `mapstructure:"auth_url" toml:"auth_url"`
|
||||
AuthStyle string `mapstructure:"auth_style" toml:"auth_style"`
|
||||
ClientId string `mapstructure:"client_id" toml:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret" toml:"-"`
|
||||
EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name"`
|
||||
EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path"`
|
||||
EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes"`
|
||||
GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path"`
|
||||
HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain"`
|
||||
Icon string `mapstructure:"icon" toml:"icon"`
|
||||
Name string `mapstructure:"name" toml:"name"`
|
||||
RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path"`
|
||||
TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path"`
|
||||
TeamsUrl string `mapstructure:"teams_url" toml:"teams_url"`
|
||||
TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca"`
|
||||
TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert"`
|
||||
TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key"`
|
||||
TokenUrl string `mapstructure:"token_url" toml:"token_url"`
|
||||
AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains"`
|
||||
AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups"`
|
||||
Scopes []string `mapstructure:"scopes" toml:"scopes"`
|
||||
AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin"`
|
||||
AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up"`
|
||||
AutoLogin bool `mapstructure:"auto_login" toml:"auto_login"`
|
||||
Enabled bool `mapstructure:"enabled" toml:"enabled"`
|
||||
RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict"`
|
||||
TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure"`
|
||||
UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce"`
|
||||
UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token"`
|
||||
SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url"`
|
||||
Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty"`
|
||||
ApiUrl string `mapstructure:"api_url" toml:"api_url" json:"apiUrl"`
|
||||
AuthUrl string `mapstructure:"auth_url" toml:"auth_url" json:"authUrl"`
|
||||
AuthStyle string `mapstructure:"auth_style" toml:"auth_style" json:"authStyle"`
|
||||
ClientId string `mapstructure:"client_id" toml:"client_id" json:"clientId"`
|
||||
ClientSecret string `mapstructure:"client_secret" toml:"-" json:"clientSecret"`
|
||||
EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name" json:"emailAttributeName"`
|
||||
EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path" json:"emailAttributePath"`
|
||||
EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes" json:"emptyScopes"`
|
||||
GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path" json:"groupsAttributePath"`
|
||||
HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain" json:"hostedDomain"`
|
||||
Icon string `mapstructure:"icon" toml:"icon" json:"icon"`
|
||||
Name string `mapstructure:"name" toml:"name" json:"name"`
|
||||
RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path" json:"roleAttributePath"`
|
||||
TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path" json:"teamIdsAttributePath"`
|
||||
TeamsUrl string `mapstructure:"teams_url" toml:"teams_url" json:"teamsUrl"`
|
||||
TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca" json:"tlsClientCa"`
|
||||
TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert" json:"tlsClientCert"`
|
||||
TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key" json:"tlsClientKey"`
|
||||
TokenUrl string `mapstructure:"token_url" toml:"token_url" json:"tokenUrl"`
|
||||
AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains" json:"allowedDomains"`
|
||||
AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups" json:"allowedGroups"`
|
||||
Scopes []string `mapstructure:"scopes" toml:"scopes" json:"scopes"`
|
||||
AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin" json:"allowAssignGrafanaAdmin"`
|
||||
AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up" json:"allowSignup"`
|
||||
AutoLogin bool `mapstructure:"auto_login" toml:"auto_login" json:"autoLogin"`
|
||||
Enabled bool `mapstructure:"enabled" toml:"enabled" json:"enabled"`
|
||||
RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict" json:"roleAttributeStrict"`
|
||||
TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure" json:"tlsSkipVerify"`
|
||||
UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce" json:"usePKCE"`
|
||||
UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token" json:"useRefreshToken"`
|
||||
SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url" json:"signoutRedirectUrl"`
|
||||
Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty" json:"extra"`
|
||||
}
|
||||
|
||||
func ProvideService(cfg *setting.Cfg,
|
||||
@ -97,7 +97,7 @@ func ProvideService(cfg *setting.Cfg,
|
||||
sec := cfg.Raw.Section("auth." + name)
|
||||
|
||||
settingsKVs := convertIniSectionToMap(sec)
|
||||
info, err := createOAuthInfoFromKeyValues(settingsKVs)
|
||||
info, err := CreateOAuthInfoFromKeyValues(settingsKVs)
|
||||
if err != nil {
|
||||
ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", name)
|
||||
continue
|
||||
|
@ -61,7 +61,18 @@ func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Re
|
||||
return response.Error(500, "Failed to get providers", err)
|
||||
}
|
||||
|
||||
return response.JSON(http.StatusOK, providers)
|
||||
dtos := make([]*models.SSOSettingsDTO, 0)
|
||||
for _, provider := range providers {
|
||||
dto, err := provider.ToSSOSettingsDTO()
|
||||
if err != nil {
|
||||
api.Log.Warn("Failed to convert SSO Settings for provider " + provider.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
dtos = append(dtos, dto)
|
||||
}
|
||||
|
||||
return response.JSON(http.StatusOK, dtos)
|
||||
}
|
||||
|
||||
func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Response {
|
||||
@ -75,7 +86,12 @@ func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Respons
|
||||
return response.Error(http.StatusNotFound, "The provider was not found", err)
|
||||
}
|
||||
|
||||
return response.JSON(http.StatusOK, settings)
|
||||
dto, err := settings.ToSSOSettingsDTO()
|
||||
if err != nil {
|
||||
return response.Error(http.StatusInternalServerError, "The provider is invalid", err)
|
||||
}
|
||||
|
||||
return response.JSON(http.StatusOK, dto)
|
||||
}
|
||||
|
||||
func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Response {
|
||||
@ -84,12 +100,19 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
|
||||
return response.Error(http.StatusBadRequest, "Missing key", nil)
|
||||
}
|
||||
|
||||
var newSettings models.SSOSetting
|
||||
if err := web.Bind(c.Req, &newSettings); err != nil {
|
||||
var settingsDTO models.SSOSettingsDTO
|
||||
if err := web.Bind(c.Req, &settingsDTO); err != nil {
|
||||
return response.Error(http.StatusBadRequest, "Failed to parse request body", err)
|
||||
}
|
||||
|
||||
err := api.SSOSettingsService.Upsert(c.Req.Context(), key, newSettings.Settings)
|
||||
settings, err := settingsDTO.ToSSOSettings()
|
||||
if err != nil {
|
||||
return response.Error(http.StatusBadRequest, "Invalid request body", err)
|
||||
}
|
||||
|
||||
settings.Provider = key
|
||||
|
||||
err = api.SSOSettingsService.Upsert(c.Req.Context(), *settings)
|
||||
// TODO: first check whether the error is referring to validation errors
|
||||
|
||||
// other error
|
||||
|
@ -31,8 +31,8 @@ func ProvideStore(sqlStore db.DB) *SSOSettingsStore {
|
||||
|
||||
var _ ssosettings.Store = (*SSOSettingsStore)(nil)
|
||||
|
||||
func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
result := models.SSOSetting{Provider: provider}
|
||||
func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
result := models.SSOSettingsDTO{Provider: provider}
|
||||
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
var err error
|
||||
sess.Table("sso_setting")
|
||||
@ -53,14 +53,19 @@ func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SS
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
dto, err := result.ToSSOSettings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dto, nil
|
||||
}
|
||||
|
||||
func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, error) {
|
||||
result := make([]*models.SSOSetting, 0)
|
||||
func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
||||
dtos := make([]*models.SSOSettingsDTO, 0)
|
||||
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
sess.Table("sso_setting")
|
||||
err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result)
|
||||
err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&dtos)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -73,13 +78,29 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
settings := make([]*models.SSOSettings, 0)
|
||||
for _, dto := range dtos {
|
||||
item, err := dto.ToSSOSettings()
|
||||
if err != nil {
|
||||
s.log.Warn("Failed to convert DB settings to SSOSettings for provider " + dto.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
settings = append(settings, item)
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
|
||||
func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
dto, err := settings.ToSSOSettingsDTO()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
existing := &models.SSOSetting{
|
||||
Provider: provider,
|
||||
existing := &models.SSOSettingsDTO{
|
||||
Provider: dto.Provider,
|
||||
IsDeleted: false,
|
||||
}
|
||||
found, err := sess.UseBool("is_deleted").Exist(existing)
|
||||
@ -90,17 +111,17 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map
|
||||
now := timeNow().UTC()
|
||||
|
||||
if found {
|
||||
updated := &models.SSOSetting{
|
||||
Settings: data,
|
||||
updated := &models.SSOSettingsDTO{
|
||||
Settings: dto.Settings,
|
||||
Updated: now,
|
||||
IsDeleted: false,
|
||||
}
|
||||
_, err = sess.UseBool("is_deleted").Update(updated, existing)
|
||||
} else {
|
||||
_, err = sess.Insert(&models.SSOSetting{
|
||||
_, err = sess.Insert(&models.SSOSettingsDTO{
|
||||
ID: uuid.New().String(),
|
||||
Provider: provider,
|
||||
Settings: data,
|
||||
Provider: dto.Provider,
|
||||
Settings: dto.Settings,
|
||||
Created: now,
|
||||
Updated: now,
|
||||
})
|
||||
@ -116,7 +137,7 @@ func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[
|
||||
|
||||
func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error {
|
||||
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
existing := &models.SSOSetting{
|
||||
existing := &models.SSOSettingsDTO{
|
||||
Provider: provider,
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
@ -7,9 +7,9 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||
@ -27,24 +27,27 @@ func TestIntegrationGetSSOSettings(t *testing.T) {
|
||||
sqlStore = db.InitTestDB(t)
|
||||
ssoSettingsStore = ProvideStore(sqlStore)
|
||||
|
||||
err := insertSSOSetting(ssoSettingsStore, "azuread", nil)
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, template, "azuread")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("returns existing SSO settings", func(t *testing.T) {
|
||||
setup()
|
||||
|
||||
expected := &models.SSOSetting{
|
||||
Provider: "azuread",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
expected := &models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
}
|
||||
|
||||
actual, err := ssoSettingsStore.Get(context.Background(), "azuread")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, maps.Equal(expected.Settings, actual.Settings))
|
||||
require.Equal(t, expected.OAuthSettings, actual.OAuthSettings)
|
||||
})
|
||||
|
||||
t.Run("returns not found if the SSO setting is missing for the specified provider", func(t *testing.T) {
|
||||
@ -83,18 +86,21 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
||||
mockTimeNow(time.Now())
|
||||
defer resetTimeNow()
|
||||
|
||||
provider := "azuread"
|
||||
settings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "azuread-client",
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "azuread-client",
|
||||
},
|
||||
}
|
||||
|
||||
err := ssoSettingsStore.Upsert(context.Background(), provider, settings)
|
||||
err := ssoSettingsStore.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
||||
actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, settings, actual.Settings)
|
||||
require.EqualValues(t, settings.OAuthSettings, actual.OAuthSettings)
|
||||
require.NotEmpty(t, actual.ID)
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created))
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
|
||||
|
||||
@ -111,25 +117,30 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
||||
defer resetTimeNow()
|
||||
|
||||
provider := "github"
|
||||
settings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "github-client",
|
||||
"client_secret": "this-is-a-secret",
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "github-client",
|
||||
ClientSecret: "this-is-a-secret",
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, settings, false, provider)
|
||||
err := populateSSOSettings(sqlStore, template, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSettings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "new-github-client",
|
||||
"client_secret": "this-is-a-new-secret",
|
||||
newSettings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "new-github-client",
|
||||
ClientSecret: "this-is-a-new-secret",
|
||||
},
|
||||
}
|
||||
err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings)
|
||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newSettings, actual.Settings)
|
||||
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
|
||||
|
||||
deleted, notDeleted, err := getSSOSettingsCountByDeleted(sqlStore)
|
||||
@ -145,32 +156,38 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
||||
defer resetTimeNow()
|
||||
|
||||
provider := "azuread"
|
||||
settings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "azuread-client",
|
||||
"client_secret": "this-is-a-secret",
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "azuread-client",
|
||||
ClientSecret: "this-is-a-secret",
|
||||
},
|
||||
IsDeleted: true,
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, settings, true, provider)
|
||||
err := populateSSOSettings(sqlStore, template, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSettings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "new-azuread-client",
|
||||
"client_secret": "this-is-a-new-secret",
|
||||
newSettings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "new-azuread-client",
|
||||
ClientSecret: "this-is-a-new-secret",
|
||||
},
|
||||
}
|
||||
|
||||
err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings)
|
||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := getSSOSettingsByProvider(sqlStore, provider, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newSettings, actual.Settings)
|
||||
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created))
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
|
||||
|
||||
old, err := getSSOSettingsByProvider(sqlStore, provider, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, settings, old.Settings)
|
||||
require.Equal(t, template.OAuthSettings, old.OAuthSettings)
|
||||
})
|
||||
|
||||
t.Run("replaces the settings only for the specified provider leaving the other provider's settings unchanged", func(t *testing.T) {
|
||||
@ -180,31 +197,36 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) {
|
||||
defer resetTimeNow()
|
||||
|
||||
providers := []string{"github", "gitlab", "google"}
|
||||
settings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "my-client",
|
||||
"client_secret": "this-is-a-secret",
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "my-client",
|
||||
ClientSecret: "this-is-a-secret",
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, settings, false, providers...)
|
||||
err := populateSSOSettings(sqlStore, template, providers...)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSettings := map[string]interface{}{
|
||||
"enabled": true,
|
||||
"client_id": "my-new-client",
|
||||
"client_secret": "this-is-a-new-secret",
|
||||
newSettings := models.SSOSettings{
|
||||
Provider: providers[0],
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
ClientId: "my-new-client",
|
||||
ClientSecret: "this-is-my-new-secret",
|
||||
},
|
||||
}
|
||||
err = ssoSettingsStore.Upsert(context.Background(), providers[0], newSettings)
|
||||
err = ssoSettingsStore.Upsert(context.Background(), newSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newSettings, actual.Settings)
|
||||
require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings)
|
||||
require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated))
|
||||
|
||||
for index := 1; index < len(providers); index++ {
|
||||
existing, err := getSSOSettingsByProvider(sqlStore, providers[index], false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, settings, existing.Settings)
|
||||
require.EqualValues(t, template.OAuthSettings, existing.OAuthSettings)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -221,14 +243,20 @@ func TestIntegrationListSSOSettings(t *testing.T) {
|
||||
sqlStore = db.InitTestDB(t)
|
||||
ssoSettingsStore = ProvideStore(sqlStore)
|
||||
|
||||
err := insertSSOSetting(ssoSettingsStore, "azuread", map[string]interface{}{
|
||||
"enabled": true,
|
||||
})
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, template, "azuread")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = insertSSOSetting(ssoSettingsStore, "okta", map[string]interface{}{
|
||||
"enabled": false,
|
||||
})
|
||||
template = models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
err = populateSSOSettings(sqlStore, template, "okta")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -259,8 +287,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
|
||||
setup()
|
||||
|
||||
providers := []string{"azuread", "github", "google"}
|
||||
|
||||
err := populateSSOSettings(sqlStore, nil, false, providers...)
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, template, providers...)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ssoSettingsStore.Delete(context.Background(), providers[0])
|
||||
@ -277,8 +309,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
|
||||
|
||||
providers := []string{"github", "google", "okta"}
|
||||
invalidProvider := "azuread"
|
||||
|
||||
err := populateSSOSettings(sqlStore, nil, false, providers...)
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, template, providers...)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ssoSettingsStore.Delete(context.Background(), invalidProvider)
|
||||
@ -295,8 +331,13 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
|
||||
setup()
|
||||
|
||||
providers := []string{"azuread", "github", "google"}
|
||||
|
||||
err := populateSSOSettings(sqlStore, nil, true, providers...)
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
IsDeleted: true,
|
||||
}
|
||||
err := populateSSOSettings(sqlStore, template, providers...)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ssoSettingsStore.Delete(context.Background(), providers[0])
|
||||
@ -313,11 +354,15 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
|
||||
setup()
|
||||
|
||||
provider := "azuread"
|
||||
|
||||
template := models.SSOSettings{
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
// insert sso for the same provider 2 times in the database
|
||||
err := populateSSOSettings(sqlStore, nil, false, provider)
|
||||
err := populateSSOSettings(sqlStore, template, provider)
|
||||
require.NoError(t, err)
|
||||
err = populateSSOSettings(sqlStore, nil, false, provider)
|
||||
err = populateSSOSettings(sqlStore, template, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ssoSettingsStore.Delete(context.Background(), provider)
|
||||
@ -330,25 +375,19 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func insertSSOSetting(ssoSettingsStore ssosettings.Store, provider string, settings map[string]interface{}) error {
|
||||
if settings == nil {
|
||||
settings = map[string]interface{}{
|
||||
"enabled": true,
|
||||
}
|
||||
}
|
||||
return ssoSettingsStore.Upsert(context.Background(), provider, settings)
|
||||
}
|
||||
|
||||
func populateSSOSettings(sqlStore *sqlstore.SQLStore, settings map[string]interface{}, deleted bool, providers ...string) error {
|
||||
func populateSSOSettings(sqlStore *sqlstore.SQLStore, template models.SSOSettings, providers ...string) error {
|
||||
return sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
|
||||
for _, provider := range providers {
|
||||
_, err := sess.Insert(&models.SSOSetting{
|
||||
ID: uuid.New().String(),
|
||||
Provider: provider,
|
||||
Settings: settings,
|
||||
Created: timeNow().UTC(),
|
||||
IsDeleted: deleted,
|
||||
})
|
||||
template.Provider = provider
|
||||
template.ID = uuid.New().String()
|
||||
template.Created = timeNow().UTC()
|
||||
|
||||
dto, err := template.ToSSOSettingsDTO()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = sess.Insert(dto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -370,8 +409,8 @@ func getSSOSettingsCountByDeleted(sqlStore *sqlstore.SQLStore) (deleted, notDele
|
||||
return
|
||||
}
|
||||
|
||||
func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSetting, error) {
|
||||
var model models.SSOSetting
|
||||
func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSettings, error) {
|
||||
var model models.SSOSettingsDTO
|
||||
var err error
|
||||
|
||||
err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
|
||||
@ -379,7 +418,16 @@ func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, dele
|
||||
return err
|
||||
})
|
||||
|
||||
return &model, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := model.ToSSOSettings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return settings, err
|
||||
}
|
||||
|
||||
func mockTimeNow(timeSeed time.Time) {
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt/strcase"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
)
|
||||
|
||||
type SettingsSource int
|
||||
@ -26,8 +26,18 @@ func (s SettingsSource) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
type SSOSetting struct {
|
||||
ID string `xorm:"id pk" json:"-"`
|
||||
type SSOSettings struct {
|
||||
ID string
|
||||
Provider string
|
||||
OAuthSettings *social.OAuthInfo
|
||||
Created time.Time
|
||||
Updated time.Time
|
||||
IsDeleted bool
|
||||
Source SettingsSource
|
||||
}
|
||||
|
||||
type SSOSettingsDTO struct {
|
||||
ID string `xorm:"id pk" json:"id"`
|
||||
Provider string `xorm:"provider" json:"provider"`
|
||||
Settings map[string]interface{} `xorm:"settings" json:"settings"`
|
||||
Created time.Time `xorm:"created" json:"-"`
|
||||
@ -37,51 +47,51 @@ type SSOSetting struct {
|
||||
}
|
||||
|
||||
// TableName returns the table name (needed for Xorm)
|
||||
func (s SSOSetting) TableName() string {
|
||||
func (s SSOSettingsDTO) TableName() string {
|
||||
return "sso_setting"
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface and converts the s.Settings from map[string]any to map[string]any in camelCase
|
||||
func (s SSOSetting) MarshalJSON() ([]byte, error) {
|
||||
type Alias SSOSetting
|
||||
aux := &struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(&s),
|
||||
func (s SSOSettingsDTO) ToSSOSettings() (*SSOSettings, error) {
|
||||
settingsEncoded, err := json.Marshal(s.Settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings := make(map[string]any)
|
||||
for k, v := range aux.Settings {
|
||||
settings[strcase.ToLowerCamel(k)] = v
|
||||
var settings social.OAuthInfo
|
||||
err = json.Unmarshal(settingsEncoded, &settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aux.Settings = settings
|
||||
return json.Marshal(aux)
|
||||
return &SSOSettings{
|
||||
ID: s.ID,
|
||||
Provider: s.Provider,
|
||||
OAuthSettings: &settings,
|
||||
Created: s.Created,
|
||||
Updated: s.Updated,
|
||||
IsDeleted: s.IsDeleted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface and converts the settings from map[string]any camelCase to map[string]interface{} snake_case
|
||||
func (s *SSOSetting) UnmarshalJSON(data []byte) error {
|
||||
type Alias SSOSetting
|
||||
aux := &struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(s),
|
||||
func (s SSOSettings) ToSSOSettingsDTO() (*SSOSettingsDTO, error) {
|
||||
settingsEncoded, err := json.Marshal(s.OAuthSettings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
var settings map[string]interface{}
|
||||
err = json.Unmarshal(settingsEncoded, &settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings := make(map[string]any)
|
||||
for k, v := range aux.Settings {
|
||||
settings[strcase.ToSnake(k)] = v
|
||||
}
|
||||
|
||||
s.Settings = settings
|
||||
return nil
|
||||
}
|
||||
|
||||
type SSOSettingsResponse struct {
|
||||
Settings map[string]interface{} `json:"settings"`
|
||||
Provider string `json:"type"`
|
||||
return &SSOSettingsDTO{
|
||||
ID: s.ID,
|
||||
Provider: s.Provider,
|
||||
Settings: settings,
|
||||
Created: s.Created,
|
||||
Updated: s.Updated,
|
||||
IsDeleted: s.IsDeleted,
|
||||
Source: s.Source,
|
||||
}, nil
|
||||
}
|
||||
|
@ -20,11 +20,11 @@ var (
|
||||
//go:generate mockery --name Service --structname MockService --outpkg ssosettingstests --filename service_mock.go --output ./ssosettingstests/
|
||||
type Service interface {
|
||||
// List returns all SSO settings from DB and config files
|
||||
List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error)
|
||||
List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error)
|
||||
// GetForProvider returns the SSO settings for a given provider (DB or config file)
|
||||
GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error)
|
||||
GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error)
|
||||
// Upsert creates or updates the SSO settings for a given provider
|
||||
Upsert(ctx context.Context, provider string, data map[string]interface{}) error
|
||||
Upsert(ctx context.Context, settings models.SSOSettings) error
|
||||
// Delete deletes the SSO settings for a given provider (soft delete)
|
||||
Delete(ctx context.Context, provider string) error
|
||||
// Patch updates the specified SSO settings (key-value pairs) for a given provider
|
||||
@ -52,9 +52,9 @@ type FallbackStrategy interface {
|
||||
//
|
||||
//go:generate mockery --name Store --structname MockStore --outpkg ssosettingstests --filename store_mock.go --output ./ssosettingstests/
|
||||
type Store interface {
|
||||
Get(ctx context.Context, provider string) (*models.SSOSetting, error)
|
||||
List(ctx context.Context) ([]*models.SSOSetting, error)
|
||||
Upsert(ctx context.Context, provider string, data map[string]interface{}) error
|
||||
Get(ctx context.Context, provider string) (*models.SSOSettings, error)
|
||||
List(ctx context.Context) ([]*models.SSOSettings, error)
|
||||
Upsert(ctx context.Context, settings models.SSOSettings) error
|
||||
Patch(ctx context.Context, provider string, data map[string]interface{}) error
|
||||
Delete(ctx context.Context, provider string) error
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
@ -55,29 +56,29 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
||||
|
||||
var _ ssosettings.Service = (*SSOSettingsService)(nil)
|
||||
|
||||
func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
dto, err := s.store.Get(ctx, provider)
|
||||
func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
storeSettings, err := s.store.Get(ctx, provider)
|
||||
|
||||
if errors.Is(err, ssosettings.ErrNotFound) {
|
||||
setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
|
||||
settings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return setting, nil
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dto.Source = models.DB
|
||||
storeSettings.Source = models.DB
|
||||
|
||||
return dto, nil
|
||||
return storeSettings, nil
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) {
|
||||
result := make([]*models.SSOSetting, 0, len(ssosettings.AllOAuthProviders))
|
||||
func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) {
|
||||
result := make([]*models.SSOSettings, 0, len(ssosettings.AllOAuthProviders))
|
||||
storedSettings, err := s.store.List(ctx)
|
||||
|
||||
if err != nil {
|
||||
@ -98,12 +99,12 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
|
||||
settings := getSettingsByProvider(provider, storedSettings)
|
||||
if len(settings) == 0 {
|
||||
// If there is no data in the DB then we need to load the settings using the fallback strategy
|
||||
setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
|
||||
fallbackSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings = append(settings, setting)
|
||||
settings = append(settings, fallbackSettings)
|
||||
}
|
||||
result = append(result, settings...)
|
||||
}
|
||||
@ -111,9 +112,9 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
|
||||
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
// TODO: validation (configurable provider? Contains the required fields? etc)
|
||||
err := s.store.Upsert(ctx, provider, data)
|
||||
err := s.store.Upsert(ctx, settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -140,7 +141,7 @@ func (s *SSOSettingsService) RegisterFallbackStrategy(providerRegex string, stra
|
||||
s.fbStrategies = append(s.fbStrategies, strategy)
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
loadStrategy, ok := s.getFallBackstrategyFor(provider)
|
||||
if !ok {
|
||||
return nil, errors.New("no fallback strategy found for provider: " + provider)
|
||||
@ -151,18 +152,23 @@ func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Conte
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.SSOSetting{
|
||||
Provider: provider,
|
||||
Source: models.System,
|
||||
Settings: settingsFromSystem,
|
||||
oAuthInfo, err := social.CreateOAuthInfoFromKeyValues(settingsFromSystem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.SSOSettings{
|
||||
Provider: provider,
|
||||
Source: models.System,
|
||||
OAuthSettings: oAuthInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSettingsByProvider(provider string, settings []*models.SSOSetting) []*models.SSOSetting {
|
||||
result := make([]*models.SSOSetting, 0)
|
||||
for _, setting := range settings {
|
||||
if setting.Provider == provider {
|
||||
result = append(result, setting)
|
||||
func getSettingsByProvider(provider string, settings []*models.SSOSettings) []*models.SSOSettings {
|
||||
result := make([]*models.SSOSettings, 0)
|
||||
for _, item := range settings {
|
||||
if item.Provider == provider {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
@ -22,25 +23,21 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(env testEnv)
|
||||
want *models.SSOSetting
|
||||
want *models.SSOSettings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "should return successfully",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSetting = &models.SSOSetting{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.DB,
|
||||
env.store.ExpectedSSOSetting = &models.SSOSettings{
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
}
|
||||
},
|
||||
want: &models.SSOSetting{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
want: &models.SSOSettings{
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -59,12 +56,10 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) {
|
||||
"enabled": true,
|
||||
}
|
||||
},
|
||||
want: &models.SSOSetting{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.System,
|
||||
want: &models.SSOSettings{
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.System,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -136,26 +131,22 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
name string
|
||||
setup func(env testEnv)
|
||||
identity identity.Requester
|
||||
want []*models.SSOSetting
|
||||
want []*models.SSOSettings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "should return successfully",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSetting{
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSettings{
|
||||
{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
},
|
||||
{
|
||||
Provider: "okta",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "okta",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.DB,
|
||||
},
|
||||
}
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -164,55 +155,41 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
}
|
||||
},
|
||||
identity: defaultIdentity,
|
||||
want: []*models.SSOSetting{
|
||||
want: []*models.SSOSettings{
|
||||
{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
},
|
||||
{
|
||||
Provider: "okta",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "okta",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.DB,
|
||||
},
|
||||
{
|
||||
Provider: "gitlab",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "gitlab",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "generic_oauth",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "generic_oauth",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "google",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "google",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "azuread",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "grafana_com",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "grafana_com",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@ -220,20 +197,16 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
{
|
||||
name: "should return the settings that the user has access to",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSetting{
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSettings{
|
||||
{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
},
|
||||
{
|
||||
Provider: "okta",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "okta",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
},
|
||||
}
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
@ -242,20 +215,16 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
}
|
||||
},
|
||||
identity: scopedIdentity,
|
||||
want: []*models.SSOSetting{
|
||||
want: []*models.SSOSettings{
|
||||
{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": true,
|
||||
},
|
||||
Source: models.DB,
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: true},
|
||||
Source: models.DB,
|
||||
},
|
||||
{
|
||||
Provider: "azuread",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@ -270,62 +239,48 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
{
|
||||
name: "should use the fallback strategy if store returns empty list",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSetting{}
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSettings{}
|
||||
env.fallbackStrategy.ExpectedIsMatch = true
|
||||
env.fallbackStrategy.ExpectedConfig = map[string]interface{}{
|
||||
"enabled": false,
|
||||
}
|
||||
},
|
||||
identity: defaultIdentity,
|
||||
want: []*models.SSOSetting{
|
||||
want: []*models.SSOSettings{
|
||||
{
|
||||
Provider: "github",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "github",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "okta",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "okta",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "gitlab",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "gitlab",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "generic_oauth",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "generic_oauth",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "google",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "google",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "azuread",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
{
|
||||
Provider: "grafana_com",
|
||||
Settings: map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
Source: models.System,
|
||||
Provider: "grafana_com",
|
||||
OAuthSettings: &social.OAuthInfo{Enabled: false},
|
||||
Source: models.System,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@ -333,7 +288,7 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
{
|
||||
name: "should return error if any of the fallback strategies was not found",
|
||||
setup: func(env testEnv) {
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSetting{}
|
||||
env.store.ExpectedSSOSettings = []*models.SSOSettings{}
|
||||
env.fallbackStrategy.ExpectedIsMatch = false
|
||||
},
|
||||
identity: defaultIdentity,
|
||||
|
@ -33,19 +33,19 @@ func (_m *MockService) Delete(ctx context.Context, provider string) error {
|
||||
}
|
||||
|
||||
// GetForProvider provides a mock function with given fields: ctx, provider
|
||||
func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
ret := _m.Called(ctx, provider)
|
||||
|
||||
var r0 *models.SSOSetting
|
||||
var r0 *models.SSOSettings
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok {
|
||||
return rf(ctx, provider)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok {
|
||||
r0 = rf(ctx, provider)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*models.SSOSetting)
|
||||
r0 = ret.Get(0).(*models.SSOSettings)
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,19 +59,19 @@ func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*mo
|
||||
}
|
||||
|
||||
// List provides a mock function with given fields: ctx, requester
|
||||
func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) {
|
||||
func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) {
|
||||
ret := _m.Called(ctx, requester)
|
||||
|
||||
var r0 []*models.SSOSetting
|
||||
var r0 []*models.SSOSettings
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSetting, error)); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSettings, error)); ok {
|
||||
return rf(ctx, requester)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSetting); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSettings); ok {
|
||||
r0 = rf(ctx, requester)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*models.SSOSetting)
|
||||
r0 = ret.Get(0).([]*models.SSOSettings)
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,13 +108,13 @@ func (_m *MockService) Reload(ctx context.Context, provider string) {
|
||||
_m.Called(ctx, provider)
|
||||
}
|
||||
|
||||
// Upsert provides a mock function with given fields: ctx, provider, data
|
||||
func (_m *MockService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
|
||||
ret := _m.Called(ctx, provider, data)
|
||||
// Upsert provides a mock function with given fields: ctx, settings
|
||||
func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
ret := _m.Called(ctx, settings)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok {
|
||||
r0 = rf(ctx, provider, data)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
||||
r0 = rf(ctx, settings)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
var _ ssosettings.Store = (*FakeStore)(nil)
|
||||
|
||||
type FakeStore struct {
|
||||
ExpectedSSOSetting *models.SSOSetting
|
||||
ExpectedSSOSettings []*models.SSOSetting
|
||||
ExpectedSSOSetting *models.SSOSettings
|
||||
ExpectedSSOSettings []*models.SSOSettings
|
||||
ExpectedError error
|
||||
}
|
||||
|
||||
@ -19,15 +19,15 @@ func NewFakeStore() *FakeStore {
|
||||
return &FakeStore{}
|
||||
}
|
||||
|
||||
func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
return f.ExpectedSSOSetting, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSetting, error) {
|
||||
func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
||||
return f.ExpectedSSOSettings, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
|
||||
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Code generated by mockery v2.27.1. DO NOT EDIT.
|
||||
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||
|
||||
package ssosettingstests
|
||||
|
||||
@ -29,19 +29,19 @@ func (_m *MockStore) Delete(ctx context.Context, provider string) error {
|
||||
}
|
||||
|
||||
// Get provides a mock function with given fields: ctx, provider
|
||||
func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) {
|
||||
func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) {
|
||||
ret := _m.Called(ctx, provider)
|
||||
|
||||
var r0 *models.SSOSetting
|
||||
var r0 *models.SSOSettings
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok {
|
||||
return rf(ctx, provider)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok {
|
||||
r0 = rf(ctx, provider)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*models.SSOSetting)
|
||||
r0 = ret.Get(0).(*models.SSOSettings)
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,20 +54,20 @@ func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetti
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetAll provides a mock function with given fields: ctx
|
||||
func (_m *MockStore) GetAll(ctx context.Context) ([]*models.SSOSetting, error) {
|
||||
// List provides a mock function with given fields: ctx
|
||||
func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 []*models.SSOSetting
|
||||
var r0 []*models.SSOSettings
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSetting, error)); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSettings, error)); ok {
|
||||
return rf(ctx)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSetting); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSettings); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*models.SSOSetting)
|
||||
r0 = ret.Get(0).([]*models.SSOSettings)
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,13 +94,13 @@ func (_m *MockStore) Patch(ctx context.Context, provider string, data map[string
|
||||
return r0
|
||||
}
|
||||
|
||||
// Upsert provides a mock function with given fields: ctx, provider, data
|
||||
func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error {
|
||||
ret := _m.Called(ctx, provider, data)
|
||||
// Upsert provides a mock function with given fields: ctx, settings
|
||||
func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
ret := _m.Called(ctx, settings)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok {
|
||||
r0 = rf(ctx, provider, data)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
||||
r0 = rf(ctx, settings)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@ -108,13 +108,12 @@ func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[strin
|
||||
return r0
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockStore interface {
|
||||
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockStore(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore {
|
||||
}) *MockStore {
|
||||
mock := &MockStore{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user