Alerting: Update notification policy service to check provenance status (#94359)

* update ResetPolicyTree to accept provenance status

* update methods to check for provenance status use relaxed validation
This commit is contained in:
Yuri Tseretyan 2024-10-10 16:26:30 -04:00 committed by GitHub
parent 75d42d82a3
commit 27c44f4709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 115 additions and 33 deletions

View File

@ -53,7 +53,7 @@ type TemplateService interface {
type NotificationPolicyService interface { type NotificationPolicyService interface {
GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error)
UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p alerting_models.Provenance, version string) error UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p alerting_models.Provenance, version string) error
ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) ResetPolicyTree(ctx context.Context, orgID int64, provenance alerting_models.Provenance) (definitions.Route, error)
} }
type MuteTimingService interface { type MuteTimingService interface {
@ -84,7 +84,7 @@ func (srv *ProvisioningSrv) RouteGetPolicyTree(c *contextmodel.ReqContext) respo
return ErrResp(http.StatusNotFound, err, "") return ErrResp(http.StatusNotFound, err, "")
} }
if err != nil { if err != nil {
return ErrResp(http.StatusInternalServerError, err, "") return response.ErrOrFallback(http.StatusInternalServerError, "failed to get notification policy tree", err)
} }
return response.JSON(http.StatusOK, policies) return response.JSON(http.StatusOK, policies)
@ -117,16 +117,17 @@ func (srv *ProvisioningSrv) RoutePutPolicyTree(c *contextmodel.ReqContext, tree
return ErrResp(http.StatusBadRequest, err, "") return ErrResp(http.StatusBadRequest, err, "")
} }
if err != nil { if err != nil {
return response.ErrOrFallback(http.StatusInternalServerError, "", err) return response.ErrOrFallback(http.StatusInternalServerError, "failed to update notification policy tree", err)
} }
return response.JSON(http.StatusAccepted, util.DynMap{"message": "policies updated"}) return response.JSON(http.StatusAccepted, util.DynMap{"message": "policies updated"})
} }
func (srv *ProvisioningSrv) RouteResetPolicyTree(c *contextmodel.ReqContext) response.Response { func (srv *ProvisioningSrv) RouteResetPolicyTree(c *contextmodel.ReqContext) response.Response {
tree, err := srv.policies.ResetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID()) provenance := determineProvenance(c)
tree, err := srv.policies.ResetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID(), alerting_models.Provenance(provenance))
if err != nil { if err != nil {
return ErrResp(http.StatusInternalServerError, err, "") return response.ErrOrFallback(http.StatusInternalServerError, "failed to reset notification policy tree", err)
} }
return response.JSON(http.StatusAccepted, tree) return response.JSON(http.StatusAccepted, tree)
} }

View File

@ -140,7 +140,6 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 500, response.Status()) require.Equal(t, 500, response.Status())
require.NotEmpty(t, response.Body()) require.NotEmpty(t, response.Body())
require.Contains(t, string(response.Body()), "something went wrong")
}) })
t.Run("PUT returns 500", func(t *testing.T) { t.Run("PUT returns 500", func(t *testing.T) {
@ -164,7 +163,6 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 500, response.Status()) require.Equal(t, 500, response.Status())
require.NotEmpty(t, response.Body()) require.NotEmpty(t, response.Body())
require.Contains(t, string(response.Body()), "something went wrong")
}) })
}) })
}) })
@ -2002,7 +2000,7 @@ func (f *fakeNotificationPolicyService) UpdatePolicyTree(ctx context.Context, or
return nil return nil
} }
func (f *fakeNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64, provenance models.Provenance) (definitions.Route, error) {
f.tree = definitions.Route{} // TODO f.tree = definitions.Route{} // TODO
return f.tree, nil return f.tree, nil
} }
@ -2017,7 +2015,7 @@ func (f *fakeFailingNotificationPolicyService) UpdatePolicyTree(ctx context.Cont
return fmt.Errorf("something went wrong") return fmt.Errorf("something went wrong")
} }
func (f *fakeFailingNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeFailingNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64, provenance models.Provenance) (definitions.Route, error) {
return definitions.Route{}, fmt.Errorf("something went wrong") return definitions.Route{}, fmt.Errorf("something went wrong")
} }
@ -2031,7 +2029,7 @@ func (f *fakeRejectingNotificationPolicyService) UpdatePolicyTree(ctx context.Co
return fmt.Errorf("%w: invalid policy tree", provisioning.ErrValidation) return fmt.Errorf("%w: invalid policy tree", provisioning.ErrValidation)
} }
func (f *fakeRejectingNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeRejectingNotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64, provenance models.Provenance) (definitions.Route, error) {
return definitions.Route{}, nil return definitions.Route{}, nil
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions" "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
"github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/notifier/legacy_storage" "github.com/grafana/grafana/pkg/services/ngalert/notifier/legacy_storage"
"github.com/grafana/grafana/pkg/services/ngalert/provisioning/validation"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
@ -25,6 +26,7 @@ type NotificationPolicyService struct {
xact TransactionManager xact TransactionManager
log log.Logger log log.Logger
settings setting.UnifiedAlertingSettings settings setting.UnifiedAlertingSettings
validator validation.ProvenanceStatusTransitionValidator
} }
func NewNotificationPolicyService(am alertmanagerConfigStore, prov ProvisioningStore, func NewNotificationPolicyService(am alertmanagerConfigStore, prov ProvisioningStore,
@ -35,6 +37,7 @@ func NewNotificationPolicyService(am alertmanagerConfigStore, prov ProvisioningS
xact: xact, xact: xact,
log: log, log: log,
settings: settings, settings: settings,
validator: validation.ValidateProvenanceRelaxed,
} }
} }
@ -69,11 +72,20 @@ func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgI
return err return err
} }
err = nps.checkOptimisticConcurrency(*revision.Config.AlertmanagerConfig.Route, models.Provenance(tree.Provenance), version, "update") err = nps.checkOptimisticConcurrency(*revision.Config.AlertmanagerConfig.Route, p, version, "update")
if err != nil { if err != nil {
return err return err
} }
// check that provenance is not changed in an invalid way
storedProvenance, err := nps.provenanceStore.GetProvenance(ctx, &tree, orgID)
if err != nil {
return err
}
if err := nps.validator(storedProvenance, p); err != nil {
return err
}
receivers, err := nps.receiversToMap(revision.Config.AlertmanagerConfig.Receivers) receivers, err := nps.receiversToMap(revision.Config.AlertmanagerConfig.Receivers)
if err != nil { if err != nil {
return err return err
@ -107,7 +119,15 @@ func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgI
}) })
} }
func (nps *NotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (nps *NotificationPolicyService) ResetPolicyTree(ctx context.Context, orgID int64, provenance models.Provenance) (definitions.Route, error) {
storedProvenance, err := nps.provenanceStore.GetProvenance(ctx, &definitions.Route{}, orgID)
if err != nil {
return definitions.Route{}, err
}
if err := nps.validator(storedProvenance, provenance); err != nil {
return definitions.Route{}, err
}
defaultCfg, err := legacy_storage.DeserializeAlertmanagerConfig([]byte(nps.settings.DefaultConfiguration)) defaultCfg, err := legacy_storage.DeserializeAlertmanagerConfig([]byte(nps.settings.DefaultConfiguration))
if err != nil { if err != nil {
nps.log.Error("Failed to parse default alertmanager config: %w", err) nps.log.Error("Failed to parse default alertmanager config: %w", err)

View File

@ -2,6 +2,7 @@ package provisioning
import ( import (
"context" "context"
"errors"
"testing" "testing"
"github.com/grafana/alerting/definition" "github.com/grafana/alerting/definition"
@ -102,7 +103,10 @@ func TestUpdatePolicyTree(t *testing.T) {
t.Run("ErrValidation if referenced receiver does not exist", func(t *testing.T) { t.Run("ErrValidation if referenced receiver does not exist", func(t *testing.T) {
rev := getDefaultConfigRevision() rev := getDefaultConfigRevision()
sut, store, _ := createNotificationPolicyServiceSut() sut, store, prov := createNotificationPolicyServiceSut()
prov.GetProvenanceFunc = func(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) {
return models.ProvenanceNone, nil
}
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) { store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil return &rev, nil
} }
@ -137,6 +141,35 @@ func TestUpdatePolicyTree(t *testing.T) {
require.ErrorIs(t, err, ErrVersionConflict) require.ErrorIs(t, err, ErrVersionConflict)
}) })
t.Run("Error if provenance validation fails", func(t *testing.T) {
sut, store, prov := createNotificationPolicyServiceSut()
prov.GetProvenanceFunc = func(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) {
return models.ProvenanceAPI, nil
}
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
}
expectedRev := getDefaultConfigRevision()
route := newRoute
expectedRev.ConcurrencyToken = rev.ConcurrencyToken
expectedRev.Config.AlertmanagerConfig.Route = &route
expectedErr := errors.New("test")
sut.validator = func(from, to models.Provenance) error {
assert.Equal(t, models.ProvenanceAPI, from)
assert.Equal(t, models.ProvenanceNone, to)
return expectedErr
}
err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, defaultVersion)
require.ErrorIs(t, err, expectedErr)
assert.Len(t, prov.Calls, 1)
assert.Equal(t, "GetProvenance", prov.Calls[0].MethodName)
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2].(int64))
})
t.Run("updates Route and sets provenance in transaction if route is valid and version matches", func(t *testing.T) { t.Run("updates Route and sets provenance in transaction if route is valid and version matches", func(t *testing.T) {
sut, store, prov := createNotificationPolicyServiceSut() sut, store, prov := createNotificationPolicyServiceSut()
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) { store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
@ -155,12 +188,16 @@ func TestUpdatePolicyTree(t *testing.T) {
assertInTransaction(t, store.Calls[1].Args[0].(context.Context)) assertInTransaction(t, store.Calls[1].Args[0].(context.Context))
assert.Equal(t, &expectedRev, store.Calls[1].Args[1]) assert.Equal(t, &expectedRev, store.Calls[1].Args[1])
assert.Len(t, prov.Calls, 1) c := prov.Calls[0]
assert.Equal(t, "SetProvenance", prov.Calls[0].MethodName) assert.Equal(t, "GetProvenance", c.MethodName)
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context)) assert.IsType(t, &definitions.Route{}, c.Arguments[1])
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1]) assert.Equal(t, orgID, c.Arguments[2].(int64))
assert.Equal(t, orgID, prov.Calls[0].Arguments[2].(int64)) c = prov.Calls[1]
assert.Equal(t, models.ProvenanceAPI, prov.Calls[0].Arguments[3].(models.Provenance)) assert.Equal(t, "SetProvenance", c.MethodName)
assertInTransaction(t, c.Arguments[0].(context.Context))
assert.IsType(t, &definitions.Route{}, c.Arguments[1])
assert.Equal(t, orgID, c.Arguments[2].(int64))
assert.Equal(t, models.ProvenanceAPI, c.Arguments[3].(models.Provenance))
}) })
t.Run("bypasses optimistic concurrency if provided version is empty", func(t *testing.T) { t.Run("bypasses optimistic concurrency if provided version is empty", func(t *testing.T) {
@ -181,12 +218,13 @@ func TestUpdatePolicyTree(t *testing.T) {
assertInTransaction(t, store.Calls[1].Args[0].(context.Context)) assertInTransaction(t, store.Calls[1].Args[0].(context.Context))
assert.Equal(t, &expectedRev, store.Calls[1].Args[1]) assert.Equal(t, &expectedRev, store.Calls[1].Args[1])
assert.Len(t, prov.Calls, 1) assert.Len(t, prov.Calls, 2)
assert.Equal(t, "SetProvenance", prov.Calls[0].MethodName) c := prov.Calls[1]
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context)) assert.Equal(t, "SetProvenance", c.MethodName)
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1]) assertInTransaction(t, c.Arguments[0].(context.Context))
assert.Equal(t, orgID, prov.Calls[0].Arguments[2].(int64)) assert.IsType(t, &definitions.Route{}, c.Arguments[1])
assert.Equal(t, models.ProvenanceAPI, prov.Calls[0].Arguments[3].(models.Provenance)) assert.Equal(t, orgID, c.Arguments[2].(int64))
assert.Equal(t, models.ProvenanceAPI, c.Arguments[3].(models.Provenance))
}) })
} }
@ -223,10 +261,27 @@ func TestResetPolicyTree(t *testing.T) {
sut.settings = setting.UnifiedAlertingSettings{ sut.settings = setting.UnifiedAlertingSettings{
DefaultConfiguration: "{", DefaultConfiguration: "{",
} }
_, err := sut.ResetPolicyTree(context.Background(), orgID) _, err := sut.ResetPolicyTree(context.Background(), orgID, models.ProvenanceNone)
require.ErrorContains(t, err, "failed to parse default alertmanager config") require.ErrorContains(t, err, "failed to parse default alertmanager config")
}) })
t.Run("Error if provenance validation fails", func(t *testing.T) {
sut, _, prov := createNotificationPolicyServiceSut()
prov.GetProvenanceFunc = func(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) {
return models.ProvenanceAPI, nil
}
expectedErr := errors.New("test")
sut.validator = func(from, to models.Provenance) error {
assert.Equal(t, models.ProvenanceAPI, from)
assert.Equal(t, models.ProvenanceNone, to)
return expectedErr
}
_, err := sut.ResetPolicyTree(context.Background(), orgID, models.ProvenanceNone)
require.ErrorIs(t, err, expectedErr)
})
t.Run("replaces route with one from the default config and copies receivers if do not exist", func(t *testing.T) { t.Run("replaces route with one from the default config and copies receivers if do not exist", func(t *testing.T) {
defaultConfig := getDefaultConfigRevision().Config defaultConfig := getDefaultConfigRevision().Config
data, err := legacy_storage.SerializeAlertmanagerConfig(*defaultConfig) data, err := legacy_storage.SerializeAlertmanagerConfig(*defaultConfig)
@ -252,7 +307,7 @@ func TestResetPolicyTree(t *testing.T) {
expectedRev.Config.AlertmanagerConfig.Route = getDefaultConfigRevision().Config.AlertmanagerConfig.Route expectedRev.Config.AlertmanagerConfig.Route = getDefaultConfigRevision().Config.AlertmanagerConfig.Route
expectedRev.Config.AlertmanagerConfig.Receivers = append(expectedRev.Config.AlertmanagerConfig.Receivers, getDefaultConfigRevision().Config.AlertmanagerConfig.Receivers[0]) expectedRev.Config.AlertmanagerConfig.Receivers = append(expectedRev.Config.AlertmanagerConfig.Receivers, getDefaultConfigRevision().Config.AlertmanagerConfig.Receivers[0])
tree, err := sut.ResetPolicyTree(context.Background(), orgID) tree, err := sut.ResetPolicyTree(context.Background(), orgID, models.ProvenanceNone)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, *defaultConfig.AlertmanagerConfig.Route, tree) assert.Equal(t, *defaultConfig.AlertmanagerConfig.Route, tree)
@ -262,11 +317,16 @@ func TestResetPolicyTree(t *testing.T) {
resetRev := store.Calls[1].Args[1].(*legacy_storage.ConfigRevision) resetRev := store.Calls[1].Args[1].(*legacy_storage.ConfigRevision)
assert.Equal(t, expectedRev.Config.AlertmanagerConfig, resetRev.Config.AlertmanagerConfig) assert.Equal(t, expectedRev.Config.AlertmanagerConfig, resetRev.Config.AlertmanagerConfig)
assert.Len(t, prov.Calls, 1) assert.Len(t, prov.Calls, 2)
assert.Equal(t, "DeleteProvenance", prov.Calls[0].MethodName) c := prov.Calls[0]
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context)) assert.Equal(t, "GetProvenance", c.MethodName)
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1]) assert.IsType(t, &definitions.Route{}, c.Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2]) assert.Equal(t, orgID, c.Arguments[2].(int64))
c = prov.Calls[1]
assert.Equal(t, "DeleteProvenance", c.MethodName)
assertInTransaction(t, c.Arguments[0].(context.Context))
assert.IsType(t, &definitions.Route{}, c.Arguments[1])
assert.Equal(t, orgID, c.Arguments[2])
}) })
} }
@ -286,6 +346,9 @@ func createNotificationPolicyServiceSut() (*NotificationPolicyService, *legacy_s
settings: setting.UnifiedAlertingSettings{ settings: setting.UnifiedAlertingSettings{
DefaultConfiguration: setting.GetAlertmanagerDefaultConfiguration(), DefaultConfiguration: setting.GetAlertmanagerDefaultConfiguration(),
}, },
validator: func(from, to models.Provenance) error {
return nil
},
}, configStore, prov }, configStore, prov
} }

View File

@ -45,7 +45,7 @@ func (c *defaultNotificationPolicyProvisioner) Unprovision(ctx context.Context,
files []*AlertingFile) error { files []*AlertingFile) error {
for _, file := range files { for _, file := range files {
for _, orgID := range file.ResetPolicies { for _, orgID := range file.ResetPolicies {
_, err := c.notificationPolicyService.ResetPolicyTree(ctx, int64(orgID)) _, err := c.notificationPolicyService.ResetPolicyTree(ctx, int64(orgID), models.ProvenanceFile)
if err != nil { if err != nil {
return fmt.Errorf("%s: %w", file.Filename, err) return fmt.Errorf("%s: %w", file.Filename, err)
} }