Alerting: Provisioning API respects global rule quota (#52180)

* Inject interface for quota service and create mock

* Check quota and return 403 if limit exceeded

* Implement tests for quota being exceeded
This commit is contained in:
Alexander Weaver 2022-07-13 17:36:17 -05:00 committed by GitHub
parent eb5a96eae9
commit 2d7389c34d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 209 additions and 22 deletions

View File

@ -52,7 +52,7 @@ type MuteTimingService interface {
type AlertRuleService interface {
GetAlertRule(ctx context.Context, orgID int64, ruleUID string) (alerting_models.AlertRule, alerting_models.Provenance, error)
CreateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance) (alerting_models.AlertRule, error)
CreateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance, userID int64) (alerting_models.AlertRule, error)
UpdateAlertRule(ctx context.Context, rule alerting_models.AlertRule, provenance alerting_models.Provenance) (alerting_models.AlertRule, error)
DeleteAlertRule(ctx context.Context, orgID int64, ruleUID string, provenance alerting_models.Provenance) error
GetRuleGroup(ctx context.Context, orgID int64, folder, group string) (definitions.AlertRuleGroup, error)
@ -254,7 +254,7 @@ func (srv *ProvisioningSrv) RouteRouteGetAlertRule(c *models.ReqContext, UID str
}
func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definitions.ProvisionedAlertRule) response.Response {
createdAlertRule, err := srv.alertRules.CreateAlertRule(c.Req.Context(), ar.UpstreamModel(), alerting_models.ProvenanceAPI)
createdAlertRule, err := srv.alertRules.CreateAlertRule(c.Req.Context(), ar.UpstreamModel(), alerting_models.ProvenanceAPI, c.UserId)
if errors.Is(err, alerting_models.ErrAlertRuleFailedValidation) {
return ErrResp(http.StatusBadRequest, err, "")
}
@ -262,6 +262,9 @@ func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definiti
if errors.Is(err, store.ErrOptimisticLock) {
return ErrResp(http.StatusConflict, err, "")
}
if errors.Is(err, alerting_models.ErrQuotaReached) {
return ErrResp(http.StatusForbidden, err, "")
}
return ErrResp(http.StatusInternalServerError, err, "")
}
ar.ID = createdAlertRule.ID

View File

@ -15,7 +15,8 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/provisioning"
"github.com/grafana/grafana/pkg/services/ngalert/store"
secrets "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/secrets"
secrets_fakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/web"
prometheus "github.com/prometheus/alertmanager/config"
@ -259,6 +260,20 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 404, response.Status())
})
t.Run("have reached the rule quota, POST returns 403", func(t *testing.T) {
env := createTestEnv(t)
quotas := provisioning.MockQuotaChecker{}
quotas.EXPECT().LimitExceeded()
env.quotas = &quotas
sut := createProvisioningSrvSutFromEnv(t, &env)
rule := createTestAlertRule("rule", 1)
rc := createTestRequestCtx()
response := sut.RoutePostAlertRule(&rc, rule)
require.Equal(t, 403, response.Status())
})
})
t.Run("alert rule groups", func(t *testing.T) {
@ -284,9 +299,21 @@ func TestProvisioningApi(t *testing.T) {
})
}
func createProvisioningSrvSut(t *testing.T) ProvisioningSrv {
// testEnvironment binds together common dependencies for testing alerting APIs.
//
// createTestEnv returns a populated value and createProvisioningSrvSutFromEnv
// reads its fields to construct the provisioning services. A test can replace a
// single field (for example quotas) before building the system under test.
type testEnvironment struct {
// secrets is handed to the contact point service.
secrets secrets.Service
// log is used by ProvisioningSrv itself and by every service it wraps.
log log.Logger
// store is the rule store given to the alert rule service.
store store.DBstore
// configs is given to the contact point, template and mute timing services.
configs provisioning.AMConfigStore
// xact is the transaction manager shared by all provisioning services.
xact provisioning.TransactionManager
// quotas is the quota checker given to the alert rule service.
quotas provisioning.QuotaChecker
// prov is the provenance store shared by all provisioning services.
prov provisioning.ProvisioningStore
}
func createTestEnv(t *testing.T) testEnvironment {
t.Helper()
secrets := secrets.NewFakeSecretsService()
secrets := secrets_fakes.NewFakeSecretsService()
log := log.NewNopLogger()
configs := &provisioning.MockAMConfigStore{}
configs.EXPECT().
@ -298,18 +325,41 @@ func createProvisioningSrvSut(t *testing.T) ProvisioningSrv {
SQLStore: sqlStore,
BaseInterval: time.Second * 10,
}
quotas := &provisioning.MockQuotaChecker{}
quotas.EXPECT().LimitOK()
xact := &provisioning.NopTransactionManager{}
prov := &provisioning.MockProvisioningStore{}
prov.EXPECT().SaveSucceeds()
prov.EXPECT().GetReturns(models.ProvenanceNone)
return testEnvironment{
secrets: secrets,
log: log,
configs: configs,
store: store,
xact: xact,
prov: prov,
quotas: quotas,
}
}
// createProvisioningSrvSut returns a ProvisioningSrv wired to the default test
// environment produced by createTestEnv. Tests that need to override a
// dependency should build the environment themselves and call
// createProvisioningSrvSutFromEnv instead.
func createProvisioningSrvSut(t *testing.T) ProvisioningSrv {
	t.Helper()

	defaults := createTestEnv(t)
	sut := createProvisioningSrvSutFromEnv(t, &defaults)
	return sut
}
func createProvisioningSrvSutFromEnv(t *testing.T, env *testEnvironment) ProvisioningSrv {
t.Helper()
return ProvisioningSrv{
log: log,
log: env.log,
policies: newFakeNotificationPolicyService(),
contactPointService: provisioning.NewContactPointService(configs, secrets, prov, xact, log),
templates: provisioning.NewTemplateService(configs, prov, xact, log),
muteTimings: provisioning.NewMuteTimingService(configs, prov, xact, log),
alertRules: provisioning.NewAlertRuleService(store, prov, xact, 60, 10, log),
contactPointService: provisioning.NewContactPointService(env.configs, env.secrets, env.prov, env.xact, env.log),
templates: provisioning.NewTemplateService(env.configs, env.prov, env.xact, env.log),
muteTimings: provisioning.NewMuteTimingService(env.configs, env.prov, env.xact, env.log),
alertRules: provisioning.NewAlertRuleService(env.store, env.prov, env.quotas, env.xact, 60, 10, env.log),
}
}

View File

@ -42,7 +42,6 @@ type RulerSrv struct {
}
var (
errQuotaReached = errors.New("quota has been exceeded")
errProvisionedResource = errors.New("request affects resources created via provisioning API")
)
@ -401,7 +400,7 @@ func (srv RulerSrv) updateAlertRulesInGroup(c *models.ReqContext, groupKey ngmod
return fmt.Errorf("failed to get alert rules quota: %w", err)
}
if limitReached {
return errQuotaReached
return ngmodels.ErrQuotaReached
}
}
return nil
@ -412,7 +411,7 @@ func (srv RulerSrv) updateAlertRulesInGroup(c *models.ReqContext, groupKey ngmod
return ErrResp(http.StatusNotFound, err, "failed to update rule group")
} else if errors.Is(err, ngmodels.ErrAlertRuleFailedValidation) || errors.Is(err, errProvisionedResource) {
return ErrResp(http.StatusBadRequest, err, "failed to update rule group")
} else if errors.Is(err, errQuotaReached) {
} else if errors.Is(err, ngmodels.ErrQuotaReached) {
return ErrResp(http.StatusForbidden, err, "")
} else if errors.Is(err, ErrAuthorization) {
return ErrResp(http.StatusUnauthorized, err, "")

View File

@ -23,6 +23,7 @@ var (
ErrRuleGroupNamespaceNotFound = errors.New("rule group not found under this namespace")
ErrAlertRuleFailedValidation = errors.New("invalid alert rule")
ErrAlertRuleUniqueConstraintViolation = errors.New("a conflicting alert rule is found: rule title under the same organisation and folder should be unique")
ErrQuotaReached = errors.New("quota has been exceeded")
)
// swagger:enum NoDataState

View File

@ -170,7 +170,7 @@ func (ng *AlertNG) init() error {
contactPointService := provisioning.NewContactPointService(store, ng.SecretsService, store, store, ng.Log)
templateService := provisioning.NewTemplateService(store, store, store, ng.Log)
muteTimingService := provisioning.NewMuteTimingService(store, store, store, ng.Log)
alertRuleService := provisioning.NewAlertRuleService(store, store, store,
alertRuleService := provisioning.NewAlertRuleService(store, store, ng.QuotaService, store,
int64(ng.Cfg.UnifiedAlerting.DefaultRuleEvaluationInterval.Seconds()),
int64(ng.Cfg.UnifiedAlerting.BaseInterval.Seconds()), ng.Log)

View File

@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
"github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/util"
)
@ -18,12 +19,14 @@ type AlertRuleService struct {
baseIntervalSeconds int64
ruleStore RuleStore
provenanceStore ProvisioningStore
quotas QuotaChecker
xact TransactionManager
log log.Logger
}
func NewAlertRuleService(ruleStore RuleStore,
provenanceStore ProvisioningStore,
quotas QuotaChecker,
xact TransactionManager,
defaultIntervalSeconds int64,
baseIntervalSeconds int64,
@ -33,6 +36,7 @@ func NewAlertRuleService(ruleStore RuleStore,
baseIntervalSeconds: baseIntervalSeconds,
ruleStore: ruleStore,
provenanceStore: provenanceStore,
quotas: quotas,
xact: xact,
log: log,
}
@ -57,7 +61,7 @@ func (service *AlertRuleService) GetAlertRule(ctx context.Context, orgID int64,
// CreateAlertRule creates a new alert rule. This function will ignore any
// interval that is set in the rule struct and use the already existing group
// interval or the default one.
func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule models.AlertRule, provenance models.Provenance) (models.AlertRule, error) {
func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule models.AlertRule, provenance models.Provenance, userID int64) (models.AlertRule, error) {
if rule.UID == "" {
rule.UID = util.GenerateShortUID()
}
@ -82,6 +86,18 @@ func (service *AlertRuleService) CreateAlertRule(ctx context.Context, rule model
} else {
return errors.New("couldn't find newly created id")
}
limitReached, err := service.quotas.CheckQuotaReached(ctx, "alert_rule", &quota.ScopeParameters{
OrgId: rule.OrgID,
UserId: userID,
})
if err != nil {
return fmt.Errorf("failed to check alert rule quota: %w", err)
}
if limitReached {
return models.ErrQuotaReached
}
return service.provenanceStore.SetProvenance(ctx, &rule, rule.OrgID, provenance)
})
if err != nil {

View File

@ -15,26 +15,29 @@ import (
func TestAlertRuleService(t *testing.T) {
ruleService := createAlertRuleService(t)
t.Run("alert rule creation should return the created id", func(t *testing.T) {
var orgID int64 = 1
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", orgID), models.ProvenanceNone)
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", orgID), models.ProvenanceNone, 0)
require.NoError(t, err)
require.NotEqual(t, 0, rule.ID, "expected to get the created id and not the zero value")
})
t.Run("alert rule creation should set the right provenance", func(t *testing.T) {
var orgID int64 = 1
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#2", orgID), models.ProvenanceAPI)
rule, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#2", orgID), models.ProvenanceAPI, 0)
require.NoError(t, err)
_, provenance, err := ruleService.GetAlertRule(context.Background(), orgID, rule.UID)
require.NoError(t, err)
require.Equal(t, models.ProvenanceAPI, provenance)
})
t.Run("alert rule group should be updated correctly", func(t *testing.T) {
var orgID int64 = 1
rule := dummyRule("test#3", orgID)
rule.RuleGroup = "a"
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone)
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err)
require.Equal(t, int64(60), rule.IntervalSeconds)
@ -46,11 +49,12 @@ func TestAlertRuleService(t *testing.T) {
require.NoError(t, err)
require.Equal(t, interval, rule.IntervalSeconds)
})
t.Run("alert rule should get interval from existing rule group", func(t *testing.T) {
var orgID int64 = 1
rule := dummyRule("test#4", orgID)
rule.RuleGroup = "b"
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone)
rule, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err)
var interval int64 = 120
@ -59,10 +63,11 @@ func TestAlertRuleService(t *testing.T) {
rule = dummyRule("test#4-1", orgID)
rule.RuleGroup = "b"
rule, err = ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone)
rule, err = ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err)
require.Equal(t, interval, rule.IntervalSeconds)
})
t.Run("updating a rule group should bump the version number", func(t *testing.T) {
const (
orgID = 123
@ -75,7 +80,7 @@ func TestAlertRuleService(t *testing.T) {
rule.UID = ruleUID
rule.RuleGroup = ruleGroup
rule.NamespaceUID = namespaceUID
_, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone)
_, err := ruleService.CreateAlertRule(context.Background(), rule, models.ProvenanceNone, 0)
require.NoError(t, err)
rule, _, err = ruleService.GetAlertRule(context.Background(), orgID, ruleUID)
@ -91,6 +96,7 @@ func TestAlertRuleService(t *testing.T) {
require.Equal(t, int64(2), rule.Version)
require.Equal(t, newInterval, rule.IntervalSeconds)
})
t.Run("alert rule provenace should be correctly checked", func(t *testing.T) {
tests := []struct {
name string
@ -139,7 +145,7 @@ func TestAlertRuleService(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
var orgID int64 = 1
rule := dummyRule(t.Name(), orgID)
rule, err := ruleService.CreateAlertRule(context.Background(), rule, test.from)
rule, err := ruleService.CreateAlertRule(context.Background(), rule, test.from, 0)
require.NoError(t, err)
_, err = ruleService.UpdateAlertRule(context.Background(), rule, test.to)
@ -151,6 +157,17 @@ func TestAlertRuleService(t *testing.T) {
})
}
})
t.Run("quota met causes create to be rejected", func(t *testing.T) {
ruleService := createAlertRuleService(t)
checker := &MockQuotaChecker{}
checker.EXPECT().LimitExceeded()
ruleService.quotas = checker
_, err := ruleService.CreateAlertRule(context.Background(), dummyRule("test#1", 1), models.ProvenanceNone, 0)
require.ErrorIs(t, err, models.ErrQuotaReached)
})
}
func createAlertRuleService(t *testing.T) AlertRuleService {
@ -160,9 +177,12 @@ func createAlertRuleService(t *testing.T) AlertRuleService {
SQLStore: sqlStore,
BaseInterval: time.Second * 10,
}
quotas := MockQuotaChecker{}
quotas.EXPECT().LimitOK()
return AlertRuleService{
ruleStore: store,
provenanceStore: store,
quotas: &quotas,
xact: sqlStore,
log: log.New("testing"),
baseIntervalSeconds: 10,

View File

@ -5,6 +5,7 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/services/quota"
)
// AMStore is a store of Alertmanager configurations.
@ -37,3 +38,9 @@ type RuleStore interface {
UpdateAlertRules(ctx context.Context, rule []store.UpdateRule) error
DeleteAlertRulesByUID(ctx context.Context, orgID int64, ruleUID ...string) error
}
// QuotaChecker represents the ability to evaluate whether quotas are met.
//
// AlertRuleService.CreateAlertRule calls it with the "alert_rule" target and
// rejects the creation with models.ErrQuotaReached when it reports true.
//go:generate mockery --name QuotaChecker --structname MockQuotaChecker --inpackage --filename quota_checker_mock.go --with-expecter
type QuotaChecker interface {
// CheckQuotaReached evaluates the quota named by target for the scope
// described by scopeParams. CreateAlertRule treats a true result as
// "limit reached" and a non-nil error as a failure of the check itself.
CheckQuotaReached(ctx context.Context, target string, scopeParams *quota.ScopeParameters) (bool, error)
}

View File

@ -0,0 +1,81 @@
// Code generated by mockery v2.12.0. DO NOT EDIT.
package provisioning
import (
context "context"
quota "github.com/grafana/grafana/pkg/services/quota"
mock "github.com/stretchr/testify/mock"
testing "testing"
)
// MockQuotaChecker is an autogenerated mock type for the QuotaChecker type
type MockQuotaChecker struct {
mock.Mock
}
type MockQuotaChecker_Expecter struct {
mock *mock.Mock
}
func (_m *MockQuotaChecker) EXPECT() *MockQuotaChecker_Expecter {
return &MockQuotaChecker_Expecter{mock: &_m.Mock}
}
// CheckQuotaReached provides a mock function with given fields: ctx, target, scopeParams
func (_m *MockQuotaChecker) CheckQuotaReached(ctx context.Context, target string, scopeParams *quota.ScopeParameters) (bool, error) {
ret := _m.Called(ctx, target, scopeParams)
var r0 bool
if rf, ok := ret.Get(0).(func(context.Context, string, *quota.ScopeParameters) bool); ok {
r0 = rf(ctx, target, scopeParams)
} else {
r0 = ret.Get(0).(bool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, *quota.ScopeParameters) error); ok {
r1 = rf(ctx, target, scopeParams)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQuotaChecker_CheckQuotaReached_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQuotaReached'
type MockQuotaChecker_CheckQuotaReached_Call struct {
*mock.Call
}
// CheckQuotaReached is a helper method to define mock.On call
// - ctx context.Context
// - target string
// - scopeParams *quota.ScopeParameters
func (_e *MockQuotaChecker_Expecter) CheckQuotaReached(ctx interface{}, target interface{}, scopeParams interface{}) *MockQuotaChecker_CheckQuotaReached_Call {
return &MockQuotaChecker_CheckQuotaReached_Call{Call: _e.mock.On("CheckQuotaReached", ctx, target, scopeParams)}
}
func (_c *MockQuotaChecker_CheckQuotaReached_Call) Run(run func(ctx context.Context, target string, scopeParams *quota.ScopeParameters)) *MockQuotaChecker_CheckQuotaReached_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*quota.ScopeParameters))
})
return _c
}
func (_c *MockQuotaChecker_CheckQuotaReached_Call) Return(_a0 bool, _a1 error) *MockQuotaChecker_CheckQuotaReached_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// NewMockQuotaChecker creates a new instance of MockQuotaChecker. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockQuotaChecker(t testing.TB) *MockQuotaChecker {
mock := &MockQuotaChecker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -170,3 +170,13 @@ func (m *MockProvisioningStore_Expecter) SaveSucceeds() *MockProvisioningStore_E
m.DeleteProvenance(mock.Anything, mock.Anything, mock.Anything).Return(nil)
return m
}
// LimitOK registers an expectation that makes CheckQuotaReached return
// (false, nil) for any arguments. It returns the expecter so that further
// expectations can be chained.
func (m *MockQuotaChecker_Expecter) LimitOK() *MockQuotaChecker_Expecter {
	const reached = false

	call := m.CheckQuotaReached(mock.Anything, mock.Anything, mock.Anything)
	call.Return(reached, nil)
	return m
}
// LimitExceeded registers an expectation that makes CheckQuotaReached return
// (true, nil) for any arguments. It returns the expecter so that further
// expectations can be chained.
func (m *MockQuotaChecker_Expecter) LimitExceeded() *MockQuotaChecker_Expecter {
	const reached = true

	call := m.CheckQuotaReached(mock.Anything, mock.Anything, mock.Anything)
	call.Return(reached, nil)
	return m
}