Alerting: Support optimistic concurrency in notification policies service (#93932)

* update notification policy provisioning service to support optimistic concurrency
* rewrite tests and include concurrency tests
This commit is contained in:
Yuri Tseretyan
2024-10-07 17:09:02 -04:00
committed by GitHub
parent 0e8fa1f5f8
commit b8df574aba
6 changed files with 438 additions and 269 deletions

View File

@@ -51,8 +51,8 @@ type TemplateService interface {
} }
type NotificationPolicyService interface { type NotificationPolicyService interface {
GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error)
UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p alerting_models.Provenance) error UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p alerting_models.Provenance, version string) error
ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) ResetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error)
} }
@@ -79,7 +79,7 @@ type AlertRuleService interface {
} }
func (srv *ProvisioningSrv) RouteGetPolicyTree(c *contextmodel.ReqContext) response.Response { func (srv *ProvisioningSrv) RouteGetPolicyTree(c *contextmodel.ReqContext) response.Response {
policies, err := srv.policies.GetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID()) policies, _, err := srv.policies.GetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID())
if errors.Is(err, store.ErrNoAlertmanagerConfiguration) { if errors.Is(err, store.ErrNoAlertmanagerConfiguration) {
return ErrResp(http.StatusNotFound, err, "") return ErrResp(http.StatusNotFound, err, "")
} }
@@ -91,7 +91,7 @@ func (srv *ProvisioningSrv) RouteGetPolicyTree(c *contextmodel.ReqContext) respo
} }
func (srv *ProvisioningSrv) RouteGetPolicyTreeExport(c *contextmodel.ReqContext) response.Response { func (srv *ProvisioningSrv) RouteGetPolicyTreeExport(c *contextmodel.ReqContext) response.Response {
policies, err := srv.policies.GetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID()) policies, _, err := srv.policies.GetPolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID())
if err != nil { if err != nil {
if errors.Is(err, store.ErrNoAlertmanagerConfiguration) { if errors.Is(err, store.ErrNoAlertmanagerConfiguration) {
return ErrResp(http.StatusNotFound, err, "") return ErrResp(http.StatusNotFound, err, "")
@@ -109,7 +109,7 @@ func (srv *ProvisioningSrv) RouteGetPolicyTreeExport(c *contextmodel.ReqContext)
func (srv *ProvisioningSrv) RoutePutPolicyTree(c *contextmodel.ReqContext, tree definitions.Route) response.Response { func (srv *ProvisioningSrv) RoutePutPolicyTree(c *contextmodel.ReqContext, tree definitions.Route) response.Response {
provenance := determineProvenance(c) provenance := determineProvenance(c)
err := srv.policies.UpdatePolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID(), tree, alerting_models.Provenance(provenance)) err := srv.policies.UpdatePolicyTree(c.Req.Context(), c.SignedInUser.GetOrgID(), tree, alerting_models.Provenance(provenance), "")
if errors.Is(err, store.ErrNoAlertmanagerConfiguration) { if errors.Is(err, store.ErrNoAlertmanagerConfiguration) {
return ErrResp(http.StatusNotFound, err, "") return ErrResp(http.StatusNotFound, err, "")
} }
@@ -117,7 +117,7 @@ func (srv *ProvisioningSrv) RoutePutPolicyTree(c *contextmodel.ReqContext, tree
return ErrResp(http.StatusBadRequest, err, "") return ErrResp(http.StatusBadRequest, err, "")
} }
if err != nil { if err != nil {
return ErrResp(http.StatusInternalServerError, err, "") return response.ErrOrFallback(http.StatusInternalServerError, "", err)
} }
return response.JSON(http.StatusAccepted, util.DynMap{"message": "policies updated"}) return response.JSON(http.StatusAccepted, util.DynMap{"message": "policies updated"})

View File

@@ -153,7 +153,6 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 500, response.Status()) require.Equal(t, 500, response.Status())
require.NotEmpty(t, response.Body()) require.NotEmpty(t, response.Body())
require.Contains(t, string(response.Body()), "something went wrong")
}) })
t.Run("DELETE returns 500", func(t *testing.T) { t.Run("DELETE returns 500", func(t *testing.T) {
@@ -1985,16 +1984,16 @@ func createFakeNotificationPolicyService() *fakeNotificationPolicyService {
} }
} }
func (f *fakeNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error) {
if orgID != 1 { if orgID != 1 {
return definitions.Route{}, store.ErrNoAlertmanagerConfiguration return definitions.Route{}, "", store.ErrNoAlertmanagerConfiguration
} }
result := f.tree result := f.tree
result.Provenance = definitions.Provenance(f.prov) result.Provenance = definitions.Provenance(f.prov)
return result, nil return result, "", nil
} }
func (f *fakeNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance) error { func (f *fakeNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance, _ string) error {
if orgID != 1 { if orgID != 1 {
return store.ErrNoAlertmanagerConfiguration return store.ErrNoAlertmanagerConfiguration
} }
@@ -2010,11 +2009,11 @@ func (f *fakeNotificationPolicyService) ResetPolicyTree(ctx context.Context, org
type fakeFailingNotificationPolicyService struct{} type fakeFailingNotificationPolicyService struct{}
func (f *fakeFailingNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeFailingNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error) {
return definitions.Route{}, fmt.Errorf("something went wrong") return definitions.Route{}, "", fmt.Errorf("something went wrong")
} }
func (f *fakeFailingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance) error { func (f *fakeFailingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance, _ string) error {
return fmt.Errorf("something went wrong") return fmt.Errorf("something went wrong")
} }
@@ -2024,11 +2023,11 @@ func (f *fakeFailingNotificationPolicyService) ResetPolicyTree(ctx context.Conte
type fakeRejectingNotificationPolicyService struct{} type fakeRejectingNotificationPolicyService struct{}
func (f *fakeRejectingNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (f *fakeRejectingNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error) {
return definitions.Route{}, nil return definitions.Route{}, "", nil
} }
func (f *fakeRejectingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance) error { func (f *fakeRejectingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance, _ string) error {
return fmt.Errorf("%w: invalid policy tree", provisioning.ErrValidation) return fmt.Errorf("%w: invalid policy tree", provisioning.ErrValidation)
} }

View File

@@ -2,7 +2,15 @@ package provisioning
import ( import (
"context" "context"
"encoding/binary"
"fmt" "fmt"
"hash"
"hash/fnv"
"slices"
"unsafe"
"github.com/prometheus/common/model"
"golang.org/x/exp/maps"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions" "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
@@ -30,28 +38,27 @@ func NewNotificationPolicyService(am alertmanagerConfigStore, prov ProvisioningS
} }
} }
func (nps *NotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, error) { func (nps *NotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (definitions.Route, string, error) {
rev, err := nps.configStore.Get(ctx, orgID) rev, err := nps.configStore.Get(ctx, orgID)
if err != nil { if err != nil {
return definitions.Route{}, err return definitions.Route{}, "", err
} }
if rev.Config.AlertmanagerConfig.Config.Route == nil { if rev.Config.AlertmanagerConfig.Config.Route == nil {
return definitions.Route{}, fmt.Errorf("no route present in current alertmanager config") return definitions.Route{}, "", fmt.Errorf("no route present in current alertmanager config")
} }
provenance, err := nps.provenanceStore.GetProvenance(ctx, rev.Config.AlertmanagerConfig.Route, orgID) provenance, err := nps.provenanceStore.GetProvenance(ctx, rev.Config.AlertmanagerConfig.Route, orgID)
if err != nil { if err != nil {
return definitions.Route{}, err return definitions.Route{}, "", err
} }
result := *rev.Config.AlertmanagerConfig.Route result := *rev.Config.AlertmanagerConfig.Route
result.Provenance = definitions.Provenance(provenance) result.Provenance = definitions.Provenance(provenance)
version := calculateRouteFingerprint(result)
return result, nil return result, version, nil
} }
func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance) error { func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance, version string) error {
err := tree.Validate() err := tree.Validate()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrValidation, err.Error()) return fmt.Errorf("%w: %s", ErrValidation, err.Error())
@@ -62,6 +69,11 @@ func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgI
return err return err
} }
err = nps.checkOptimisticConcurrency(*revision.Config.AlertmanagerConfig.Route, models.Provenance(tree.Provenance), version, "update")
if err != nil {
return err
}
receivers, err := nps.receiversToMap(revision.Config.AlertmanagerConfig.Receivers) receivers, err := nps.receiversToMap(revision.Config.AlertmanagerConfig.Receivers)
if err != nil { if err != nil {
return err return err
@@ -154,3 +166,109 @@ func (nps *NotificationPolicyService) ensureDefaultReceiverExists(cfg *definitio
nps.log.Error("Grafana Alerting has been configured with a default configuration that is internally inconsistent! The default configuration's notification policy must have a corresponding receiver.") nps.log.Error("Grafana Alerting has been configured with a default configuration that is internally inconsistent! The default configuration's notification policy must have a corresponding receiver.")
return fmt.Errorf("inconsistent default configuration") return fmt.Errorf("inconsistent default configuration")
} }
// calculateRouteFingerprint returns the version string of a routing tree:
// the 64-bit FNV-1a hash of the bytes produced by writeToHash for the route,
// rendered as 16 zero-padded lowercase hexadecimal digits.
func calculateRouteFingerprint(route definitions.Route) string {
	hasher := fnv.New64a()
	writeToHash(hasher, &route)
	fingerprint := hasher.Sum64()
	return fmt.Sprintf("%016x", fingerprint)
}
// writeToHash feeds the fields of r that make up the routing-tree fingerprint
// into sum and then recurses into r.Routes in slice order.
//
// Fields are written in a fixed order: Receiver, GroupByStr, GroupBy,
// GroupByAll, Match and MatchRE (both iterated by sorted key so that map
// iteration order does not affect the result), Matchers, MuteTimeIntervals,
// ActiveTimeIntervals, Continue, GroupWait, GroupInterval, RepeatInterval.
// Every written value is followed by a single 0xFF terminator byte.
//
// NOTE(review): list-valued fields are written back to back without a length
// prefix or field tag, so moving the same value between neighbouring lists
// (e.g. from MuteTimeIntervals to ActiveTimeIntervals) produces an identical
// byte stream and therefore an identical fingerprint. Confirm whether this is
// acceptable for the optimistic concurrency check.
func writeToHash(sum hash.Hash, r *definitions.Route) {
	// hash.Hash.Write is documented to never return an error, so its results
	// are deliberately discarded.
	writeBytes := func(b []byte) {
		_, _ = sum.Write(b)
		// add a byte sequence that cannot happen in UTF-8 strings.
		_, _ = sum.Write([]byte{255})
	}
	writeString := func(s string) {
		if len(s) == 0 {
			writeBytes(nil)
			return
		}
		// #nosec G103
		// avoid allocation when converting string to byte slice
		writeBytes(unsafe.Slice(unsafe.StringData(s), len(s)))
	}
	// this temp slice is used to convert ints to bytes.
	tmp := make([]byte, 8)
	writeInt := func(u int64) {
		binary.LittleEndian.PutUint64(tmp, uint64(u))
		writeBytes(tmp)
	}
	writeBool := func(b bool) {
		if b {
			writeInt(1)
			return
		}
		writeInt(0)
	}
	// A nil duration contributes only the terminator; a set duration
	// contributes its 8 little-endian bytes followed by the terminator.
	writeDuration := func(d *model.Duration) {
		if d == nil {
			writeBytes(nil)
			return
		}
		writeInt(int64(*d))
	}

	writeString(r.Receiver)
	for _, s := range r.GroupByStr {
		writeString(s)
	}
	for _, labelName := range r.GroupBy {
		writeString(string(labelName))
	}
	writeBool(r.GroupByAll)
	if len(r.Match) > 0 {
		keys := maps.Keys(r.Match)
		slices.Sort(keys)
		for _, key := range keys {
			writeString(key)
			writeString(r.Match[key])
		}
	}
	if len(r.MatchRE) > 0 {
		keys := maps.Keys(r.MatchRE)
		slices.Sort(keys)
		for _, key := range keys {
			writeString(key)
			re := r.MatchRE[key]
			encoded, err := re.MarshalJSON()
			if err != nil {
				// Fall back to the formatted value of this entry only, and do
				// not additionally hash the (nil) result of the failed marshal.
				writeString(fmt.Sprintf("%+v", re))
				continue
			}
			writeBytes(encoded)
		}
	}
	for _, matcher := range r.Matchers {
		writeString(matcher.String())
	}
	for _, timeInterval := range r.MuteTimeIntervals {
		writeString(timeInterval)
	}
	for _, timeInterval := range r.ActiveTimeIntervals {
		writeString(timeInterval)
	}
	writeBool(r.Continue)
	writeDuration(r.GroupWait)
	writeDuration(r.GroupInterval)
	writeDuration(r.RepeatInterval)
	for _, route := range r.Routes {
		writeToHash(sum, route)
	}
}
// checkOptimisticConcurrency compares desiredVersion with the fingerprint of
// the current routing tree.
//
// An empty desiredVersion disables the check and nil is returned; in that case
// a debug message tagged with action is logged unless provenance is
// models.ProvenanceFile. Otherwise the fingerprint of current is calculated
// and an ErrVersionConflict error is returned when it differs from
// desiredVersion.
func (nps *NotificationPolicyService) checkOptimisticConcurrency(current definitions.Route, provenance models.Provenance, desiredVersion string, action string) error {
	if desiredVersion != "" {
		actualVersion := calculateRouteFingerprint(current)
		if actualVersion == desiredVersion {
			return nil
		}
		return ErrVersionConflict.Errorf("provided version %s of routing tree does not match current version %s", desiredVersion, actualVersion)
	}

	if provenance == models.ProvenanceFile {
		return nil
	}
	// if version is not specified and it's not a file provisioning, emit a log message to reflect that optimistic concurrency is disabled for this request
	nps.log.Debug("ignoring optimistic concurrency check because version was not provided", "operation", action)
	return nil
}

View File

@@ -4,10 +4,9 @@ import (
"context" "context"
"testing" "testing"
"github.com/grafana/alerting/definition"
"github.com/prometheus/alertmanager/config" "github.com/prometheus/alertmanager/config"
"github.com/prometheus/alertmanager/timeinterval" "github.com/stretchr/testify/assert"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
@@ -16,274 +15,304 @@ import (
"github.com/grafana/grafana/pkg/services/ngalert/notifier/legacy_storage" "github.com/grafana/grafana/pkg/services/ngalert/notifier/legacy_storage"
"github.com/grafana/grafana/pkg/services/ngalert/tests/fakes" "github.com/grafana/grafana/pkg/services/ngalert/tests/fakes"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
) )
func TestNotificationPolicyService(t *testing.T) { func TestGetPolicyTree(t *testing.T) {
t.Run("service gets policy tree from org's AM config", func(t *testing.T) { orgID := int64(1)
sut := createNotificationPolicyServiceSut() rev := getDefaultConfigRevision()
expectedVersion := calculateRouteFingerprint(*rev.Config.AlertmanagerConfig.Route)
tree, err := sut.GetPolicyTree(context.Background(), 1) sut, store, prov := createNotificationPolicyServiceSut()
require.NoError(t, err) store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
}
expectedProvenance := models.ProvenanceAPI
prov.GetProvenanceFunc = func(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) {
return models.ProvenanceAPI, nil
}
require.Equal(t, "grafana-default-email", tree.Receiver) tree, version, err := sut.GetPolicyTree(context.Background(), orgID)
}) require.NoError(t, err)
t.Run("error if referenced mute time interval is not existing", func(t *testing.T) { expectedRoute := *rev.Config.AlertmanagerConfig.Route
sut := createNotificationPolicyServiceSut() expectedRoute.Provenance = definitions.Provenance(models.ProvenanceAPI)
mockStore := &legacy_storage.MockAMConfigStore{} assert.Equal(t, expectedRoute, tree)
sut.configStore = legacy_storage.NewAlertmanagerConfigStore(mockStore) assert.Equal(t, expectedVersion, version)
cfg := createTestAlertingConfig() assert.Equal(t, expectedProvenance, models.Provenance(tree.Provenance))
cfg.AlertmanagerConfig.MuteTimeIntervals = []config.MuteTimeInterval{
assert.Len(t, store.Calls, 1)
assert.Equal(t, "Get", store.Calls[0].Method)
assert.Equal(t, orgID, store.Calls[0].Args[1])
assert.Len(t, prov.Calls, 1)
assert.Equal(t, "GetProvenance", prov.Calls[0].MethodName)
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2])
}
func TestUpdatePolicyTree(t *testing.T) {
orgID := int64(1)
rev := getDefaultConfigRevision()
defaultVersion := calculateRouteFingerprint(*rev.Config.AlertmanagerConfig.Route)
newRoute := definitions.Route{
Receiver: rev.Config.AlertmanagerConfig.Receivers[0].Name,
Routes: []*definitions.Route{
{ {
Name: "not-the-one-we-need", Receiver: "",
TimeIntervals: []timeinterval.TimeInterval{}, MuteTimeIntervals: []string{
rev.Config.AlertmanagerConfig.TimeIntervals[0].Name,
},
},
{
Receiver: rev.Config.AlertmanagerConfig.Receivers[0].Name,
},
},
}
t.Run("ErrValidation if referenced mute time interval does not exist", func(t *testing.T) {
sut, store, _ := createNotificationPolicyServiceSut()
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
}
newRoute := definitions.Route{
Receiver: rev.Config.AlertmanagerConfig.Receivers[0].Name,
MuteTimeIntervals: []string{
"not-existing",
}, },
} }
data, _ := legacy_storage.SerializeAlertmanagerConfig(*cfg) err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, defaultVersion)
mockStore.On("GetLatestAlertmanagerConfiguration", mock.Anything, mock.Anything).
Return(&models.AlertConfiguration{AlertmanagerConfiguration: string(data)}, nil)
mockStore.EXPECT().
UpdateAlertmanagerConfiguration(mock.Anything, mock.Anything).
Return(nil)
newRoute := createTestRoutingTree()
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "slack receiver",
MuteTimeIntervals: []string{"not-existing"},
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.Error(t, err)
})
t.Run("pass if referenced mute time interval is existing", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
mockStore := &legacy_storage.MockAMConfigStore{}
sut.configStore = legacy_storage.NewAlertmanagerConfigStore(mockStore)
cfg := createTestAlertingConfig()
cfg.AlertmanagerConfig.MuteTimeIntervals = []config.MuteTimeInterval{
{
Name: "existing",
TimeIntervals: []timeinterval.TimeInterval{},
},
}
cfg.AlertmanagerConfig.TimeIntervals = []config.TimeInterval{
{
Name: "existing-ti",
TimeIntervals: []timeinterval.TimeInterval{},
},
}
data, _ := legacy_storage.SerializeAlertmanagerConfig(*cfg)
mockStore.On("GetLatestAlertmanagerConfiguration", mock.Anything, mock.Anything).
Return(&models.AlertConfiguration{AlertmanagerConfiguration: string(data)}, nil)
mockStore.EXPECT().
UpdateAlertmanagerConfiguration(mock.Anything, mock.Anything).
Return(nil)
newRoute := createTestRoutingTree()
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "slack receiver",
MuteTimeIntervals: []string{"existing", "existing-ti"},
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.NoError(t, err)
})
t.Run("service stitches policy tree into org's AM config", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
newRoute := createTestRoutingTree()
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.NoError(t, err)
updated, err := sut.GetPolicyTree(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, "slack receiver", updated.Receiver)
})
t.Run("no root receiver will error", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
newRoute := createTestRoutingTree()
newRoute.Receiver = ""
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "",
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.EqualError(t, err, "invalid object specification: root route must specify a default receiver")
})
t.Run("allow receiver inheritance", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
newRoute := createTestRoutingTree()
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "",
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.NoError(t, err)
})
t.Run("not existing receiver reference will error", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
newRoute := createTestRoutingTree()
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "not-existing",
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.Error(t, err)
})
t.Run("existing receiver reference will pass", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
mockStore := &legacy_storage.MockAMConfigStore{}
sut.configStore = legacy_storage.NewAlertmanagerConfigStore(mockStore)
cfg := createTestAlertingConfig()
data, _ := legacy_storage.SerializeAlertmanagerConfig(*cfg)
mockStore.On("GetLatestAlertmanagerConfiguration", mock.Anything, mock.Anything).
Return(&models.AlertConfiguration{AlertmanagerConfiguration: string(data)}, nil)
mockStore.EXPECT().
UpdateAlertmanagerConfiguration(mock.Anything, mock.Anything).
Return(nil)
newRoute := createTestRoutingTree()
newRoute.Routes = append(newRoute.Routes, &definitions.Route{
Receiver: "existing",
})
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceNone)
require.NoError(t, err)
})
t.Run("default provenance of records is none", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
tree, err := sut.GetPolicyTree(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, models.ProvenanceNone, models.Provenance(tree.Provenance))
})
t.Run("service returns upgraded provenance value", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
newRoute := createTestRoutingTree()
err := sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceAPI)
require.NoError(t, err)
updated, err := sut.GetPolicyTree(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, models.ProvenanceAPI, models.Provenance(updated.Provenance))
})
t.Run("service respects concurrency token when updating", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
fake := fakes.NewFakeAlertmanagerConfigStore(defaultAlertmanagerConfigJSON)
sut.configStore = legacy_storage.NewAlertmanagerConfigStore(fake)
newRoute := createTestRoutingTree()
config, err := sut.configStore.Get(context.Background(), 1)
require.NoError(t, err)
expectedConcurrencyToken := config.ConcurrencyToken
err = sut.UpdatePolicyTree(context.Background(), 1, newRoute, models.ProvenanceAPI)
require.NoError(t, err)
intercepted := fake.LastSaveCommand
require.Equal(t, expectedConcurrencyToken, intercepted.FetchedConfigurationHash)
})
t.Run("updating invalid route returns ValidationError", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
invalid := createTestRoutingTree()
repeat := model.Duration(0)
invalid.RepeatInterval = &repeat
err := sut.UpdatePolicyTree(context.Background(), 1, invalid, models.ProvenanceNone)
require.Error(t, err)
require.ErrorIs(t, err, ErrValidation) require.ErrorIs(t, err, ErrValidation)
}) })
t.Run("deleting route replaces with default", func(t *testing.T) { t.Run("ErrValidation if root route has no receiver", func(t *testing.T) {
sut := createNotificationPolicyServiceSut() rev := getDefaultConfigRevision()
sut, store, _ := createNotificationPolicyServiceSut()
tree, err := sut.ResetPolicyTree(context.Background(), 1) store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
require.NoError(t, err) }
require.Equal(t, "grafana-default-email", tree.Receiver) newRoute := definitions.Route{
require.Nil(t, tree.Routes) Receiver: "",
require.Equal(t, []model.LabelName{models.FolderTitleLabel, model.AlertNameLabel}, tree.GroupBy) }
err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, defaultVersion)
require.ErrorIs(t, err, ErrValidation)
}) })
t.Run("deleting route with missing default receiver restores receiver", func(t *testing.T) { t.Run("ErrValidation if referenced receiver does not exist", func(t *testing.T) {
sut := createNotificationPolicyServiceSut() rev := getDefaultConfigRevision()
mockStore := &legacy_storage.MockAMConfigStore{} sut, store, _ := createNotificationPolicyServiceSut()
sut.configStore = legacy_storage.NewAlertmanagerConfigStore(mockStore) store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
cfg := createTestAlertingConfig() return &rev, nil
cfg.AlertmanagerConfig.Route = &definitions.Route{
Receiver: "slack receiver",
} }
cfg.AlertmanagerConfig.Receivers = []*definitions.PostableApiReceiver{ newRoute := definitions.Route{
{ Receiver: "unknown",
Receiver: config.Receiver{ }
Name: "slack receiver", err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, defaultVersion)
require.ErrorIs(t, err, ErrValidation)
t.Run("including sub-routes", func(t *testing.T) {
newRoute := definitions.Route{
Receiver: rev.Config.AlertmanagerConfig.Receivers[0].Name,
Routes: []*definitions.Route{
{Receiver: "unknown"},
}, },
}, }
// No default receiver! Only our custom one. err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, defaultVersion)
require.ErrorIs(t, err, ErrValidation)
})
})
t.Run("ErrVersionConflict if provided version does not match current", func(t *testing.T) {
rev := getDefaultConfigRevision()
sut, store, _ := createNotificationPolicyServiceSut()
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
} }
data, _ := legacy_storage.SerializeAlertmanagerConfig(*cfg) newRoute := definitions.Route{
mockStore.On("GetLatestAlertmanagerConfiguration", mock.Anything, mock.Anything). Receiver: rev.Config.AlertmanagerConfig.Receivers[0].Name,
Return(&models.AlertConfiguration{AlertmanagerConfiguration: string(data)}, nil) }
var interceptedSave = models.SaveAlertmanagerConfigurationCmd{} err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceNone, "wrong-version")
mockStore.EXPECT().SaveSucceedsIntercept(&interceptedSave) require.ErrorIs(t, err, ErrVersionConflict)
})
tree, err := sut.ResetPolicyTree(context.Background(), 1) t.Run("updates Route and sets provenance in transaction if route is valid and version matches", func(t *testing.T) {
sut, store, prov := createNotificationPolicyServiceSut()
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
}
expectedRev := getDefaultConfigRevision()
route := newRoute
expectedRev.ConcurrencyToken = rev.ConcurrencyToken
expectedRev.Config.AlertmanagerConfig.Route = &route
err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceAPI, defaultVersion)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "grafana-default-email", tree.Receiver)
require.NotEmpty(t, interceptedSave.AlertmanagerConfiguration) assert.Len(t, store.Calls, 2)
// Deserializing with no error asserts that the saved configStore is semantically valid. assert.Equal(t, "Save", store.Calls[1].Method)
newCfg, err := legacy_storage.DeserializeAlertmanagerConfig([]byte(interceptedSave.AlertmanagerConfiguration)) assertInTransaction(t, store.Calls[1].Args[0].(context.Context))
assert.Equal(t, &expectedRev, store.Calls[1].Args[1])
assert.Len(t, prov.Calls, 1)
assert.Equal(t, "SetProvenance", prov.Calls[0].MethodName)
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context))
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2].(int64))
assert.Equal(t, models.ProvenanceAPI, prov.Calls[0].Arguments[3].(models.Provenance))
})
t.Run("bypasses optimistic concurrency if provided version is empty", func(t *testing.T) {
sut, store, prov := createNotificationPolicyServiceSut()
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
return &rev, nil
}
expectedRev := getDefaultConfigRevision()
expectedRev.Config.AlertmanagerConfig.Route = &newRoute
expectedRev.ConcurrencyToken = rev.ConcurrencyToken
err := sut.UpdatePolicyTree(context.Background(), orgID, newRoute, models.ProvenanceAPI, "")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newCfg.AlertmanagerConfig.Receivers, 2)
assert.Len(t, store.Calls, 2)
assert.Equal(t, "Save", store.Calls[1].Method)
assertInTransaction(t, store.Calls[1].Args[0].(context.Context))
assert.Equal(t, &expectedRev, store.Calls[1].Args[1])
assert.Len(t, prov.Calls, 1)
assert.Equal(t, "SetProvenance", prov.Calls[0].MethodName)
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context))
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2].(int64))
assert.Equal(t, models.ProvenanceAPI, prov.Calls[0].Arguments[3].(models.Provenance))
}) })
} }
func createNotificationPolicyServiceSut() *NotificationPolicyService { func TestResetPolicyTree(t *testing.T) {
orgID := int64(1)
currentRevision := getDefaultConfigRevision()
currentRevision.Config.AlertmanagerConfig.Route = &definitions.Route{
Receiver: "receiver",
}
currentRevision.Config.TemplateFiles = map[string]string{
"test": "test",
}
currentRevision.Config.AlertmanagerConfig.TimeIntervals = []config.TimeInterval{
{
Name: "test",
},
}
currentRevision.Config.AlertmanagerConfig.Receivers = []*definitions.PostableApiReceiver{
{
Receiver: config.Receiver{Name: "receiver"},
PostableGrafanaReceivers: definitions.PostableGrafanaReceivers{
GrafanaManagedReceivers: []*definitions.PostableGrafanaReceiver{
{
UID: "test", Name: "test", Type: "email", Settings: []byte("{}"),
},
},
},
},
}
t.Run("Error if default config is invalid", func(t *testing.T) {
sut, _, _ := createNotificationPolicyServiceSut()
sut.settings = setting.UnifiedAlertingSettings{
DefaultConfiguration: "{",
}
_, err := sut.ResetPolicyTree(context.Background(), orgID)
require.ErrorContains(t, err, "failed to parse default alertmanager config")
})
t.Run("replaces route with one from the default config and copies receivers if do not exist", func(t *testing.T) {
defaultConfig := getDefaultConfigRevision().Config
data, err := legacy_storage.SerializeAlertmanagerConfig(*defaultConfig)
require.NoError(t, err)
sut, store, prov := createNotificationPolicyServiceSut()
sut.settings = setting.UnifiedAlertingSettings{
DefaultConfiguration: string(data),
}
store.GetFn = func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
data, err := legacy_storage.SerializeAlertmanagerConfig(*currentRevision.Config)
require.NoError(t, err)
cfg, err := legacy_storage.DeserializeAlertmanagerConfig(data)
require.NoError(t, err)
return &legacy_storage.ConfigRevision{
Config: cfg,
ConcurrencyToken: util.GenerateShortUID(),
}, nil
}
expectedRev := currentRevision
expectedRev.Config.AlertmanagerConfig.Route = getDefaultConfigRevision().Config.AlertmanagerConfig.Route
expectedRev.Config.AlertmanagerConfig.Receivers = append(expectedRev.Config.AlertmanagerConfig.Receivers, getDefaultConfigRevision().Config.AlertmanagerConfig.Receivers[0])
tree, err := sut.ResetPolicyTree(context.Background(), orgID)
require.NoError(t, err)
assert.Equal(t, *defaultConfig.AlertmanagerConfig.Route, tree)
assert.Len(t, store.Calls, 2)
assert.Equal(t, "Save", store.Calls[1].Method)
assertInTransaction(t, store.Calls[1].Args[0].(context.Context))
resetRev := store.Calls[1].Args[1].(*legacy_storage.ConfigRevision)
assert.Equal(t, expectedRev.Config.AlertmanagerConfig, resetRev.Config.AlertmanagerConfig)
assert.Len(t, prov.Calls, 1)
assert.Equal(t, "DeleteProvenance", prov.Calls[0].MethodName)
assertInTransaction(t, prov.Calls[0].Arguments[0].(context.Context))
assert.IsType(t, &definitions.Route{}, prov.Calls[0].Arguments[1])
assert.Equal(t, orgID, prov.Calls[0].Arguments[2])
})
}
func createNotificationPolicyServiceSut() (*NotificationPolicyService, *legacy_storage.AlertmanagerConfigStoreFake, *fakes.FakeProvisioningStore) {
prov := fakes.NewFakeProvisioningStore()
configStore := &legacy_storage.AlertmanagerConfigStoreFake{
GetFn: func(ctx context.Context, orgID int64) (*legacy_storage.ConfigRevision, error) {
rev := getDefaultConfigRevision()
return &rev, nil
},
}
return &NotificationPolicyService{ return &NotificationPolicyService{
configStore: legacy_storage.NewAlertmanagerConfigStore(fakes.NewFakeAlertmanagerConfigStore(defaultAlertmanagerConfigJSON)), configStore: configStore,
provenanceStore: fakes.NewFakeProvisioningStore(), provenanceStore: prov,
xact: newNopTransactionManager(), xact: newNopTransactionManager(),
log: log.NewNopLogger(), log: log.NewNopLogger(),
settings: setting.UnifiedAlertingSettings{ settings: setting.UnifiedAlertingSettings{
DefaultConfiguration: setting.GetAlertmanagerDefaultConfiguration(), DefaultConfiguration: setting.GetAlertmanagerDefaultConfiguration(),
}, },
} }, configStore, prov
} }
func createTestRoutingTree() definitions.Route { func getDefaultConfigRevision() legacy_storage.ConfigRevision {
return definitions.Route{ return legacy_storage.ConfigRevision{
Receiver: "slack receiver", Config: &definitions.PostableUserConfig{
AlertmanagerConfig: definitions.PostableApiAlertingConfig{
Config: definition.Config{
Route: &definitions.Route{
Receiver: "test-receiver",
},
InhibitRules: nil,
TimeIntervals: []config.TimeInterval{
{
Name: "test-mute-interval",
},
},
},
Receivers: []*definitions.PostableApiReceiver{
{
Receiver: config.Receiver{
Name: "test-receiver",
},
},
},
},
},
ConcurrencyToken: util.GenerateShortUID(),
} }
} }
// createTestAlertingConfig returns the default Alertmanager configuration
// extended with two extra receivers: "slack receiver" and "existing".
//
// It panics if the built-in default configuration cannot be deserialized. The
// helper has no *testing.T to report through, and continuing with a nil config
// would only fail later with an unexplained nil dereference.
func createTestAlertingConfig() *definitions.PostableUserConfig {
	cfg, err := legacy_storage.DeserializeAlertmanagerConfig([]byte(setting.GetAlertmanagerDefaultConfiguration()))
	if err != nil {
		panic("createTestAlertingConfig: cannot deserialize default alertmanager configuration: " + err.Error())
	}

	cfg.AlertmanagerConfig.Receivers = append(cfg.AlertmanagerConfig.Receivers,
		&definitions.PostableApiReceiver{
			Receiver: config.Receiver{
				// default one from createTestRoutingTree()
				Name: "slack receiver",
			},
		},
		&definitions.PostableApiReceiver{
			Receiver: config.Receiver{
				Name: "existing",
			},
		},
	)
	return cfg
}

View File

@@ -8,7 +8,12 @@ import (
) )
type FakeProvisioningStore struct { type FakeProvisioningStore struct {
Records map[int64]map[string]models.Provenance Calls []Call
Records map[int64]map[string]models.Provenance
GetProvenanceFunc func(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error)
GetProvenancesFunc func(ctx context.Context, orgID int64, resourceType string) (map[string]models.Provenance, error)
SetProvenanceFunc func(ctx context.Context, o models.Provisionable, org int64, p models.Provenance) error
DeleteProvenanceFunc func(ctx context.Context, o models.Provisionable, org int64) error
} }
func NewFakeProvisioningStore() *FakeProvisioningStore { func NewFakeProvisioningStore() *FakeProvisioningStore {
@@ -18,6 +23,10 @@ func NewFakeProvisioningStore() *FakeProvisioningStore {
} }
func (f *FakeProvisioningStore) GetProvenance(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) { func (f *FakeProvisioningStore) GetProvenance(ctx context.Context, o models.Provisionable, org int64) (models.Provenance, error) {
f.Calls = append(f.Calls, Call{MethodName: "GetProvenance", Arguments: []any{ctx, o, org}})
if f.GetProvenanceFunc != nil {
return f.GetProvenanceFunc(ctx, o, org)
}
if val, ok := f.Records[org]; ok { if val, ok := f.Records[org]; ok {
if prov, ok := val[o.ResourceID()+o.ResourceType()]; ok { if prov, ok := val[o.ResourceID()+o.ResourceType()]; ok {
return prov, nil return prov, nil
@@ -27,6 +36,10 @@ func (f *FakeProvisioningStore) GetProvenance(ctx context.Context, o models.Prov
} }
func (f *FakeProvisioningStore) GetProvenances(ctx context.Context, orgID int64, resourceType string) (map[string]models.Provenance, error) { func (f *FakeProvisioningStore) GetProvenances(ctx context.Context, orgID int64, resourceType string) (map[string]models.Provenance, error) {
f.Calls = append(f.Calls, Call{MethodName: "GetProvenances", Arguments: []any{ctx, orgID, resourceType}})
if f.GetProvenancesFunc != nil {
return f.GetProvenancesFunc(ctx, orgID, resourceType)
}
results := make(map[string]models.Provenance) results := make(map[string]models.Provenance)
if val, ok := f.Records[orgID]; ok { if val, ok := f.Records[orgID]; ok {
for k, v := range val { for k, v := range val {
@@ -39,15 +52,25 @@ func (f *FakeProvisioningStore) GetProvenances(ctx context.Context, orgID int64,
} }
func (f *FakeProvisioningStore) SetProvenance(ctx context.Context, o models.Provisionable, org int64, p models.Provenance) error { func (f *FakeProvisioningStore) SetProvenance(ctx context.Context, o models.Provisionable, org int64, p models.Provenance) error {
f.Calls = append(f.Calls, Call{MethodName: "SetProvenance", Arguments: []any{ctx, o, org, p}})
if f.SetProvenanceFunc != nil {
return f.SetProvenanceFunc(ctx, o, org, p)
}
if _, ok := f.Records[org]; !ok { if _, ok := f.Records[org]; !ok {
f.Records[org] = map[string]models.Provenance{} f.Records[org] = map[string]models.Provenance{}
} }
_ = f.DeleteProvenance(ctx, o, org) // delete old entries first if val, ok := f.Records[org]; ok {
delete(val, o.ResourceID()+o.ResourceType())
}
f.Records[org][o.ResourceID()+o.ResourceType()] = p f.Records[org][o.ResourceID()+o.ResourceType()] = p
return nil return nil
} }
func (f *FakeProvisioningStore) DeleteProvenance(ctx context.Context, o models.Provisionable, org int64) error { func (f *FakeProvisioningStore) DeleteProvenance(ctx context.Context, o models.Provisionable, org int64) error {
f.Calls = append(f.Calls, Call{MethodName: "DeleteProvenance", Arguments: []any{ctx, o, org}})
if f.DeleteProvenanceFunc != nil {
return f.DeleteProvenanceFunc(ctx, o, org)
}
if val, ok := f.Records[org]; ok { if val, ok := f.Records[org]; ok {
delete(val, o.ResourceID()+o.ResourceType()) delete(val, o.ResourceID()+o.ResourceType())
} }

View File

@@ -32,7 +32,7 @@ func (c *defaultNotificationPolicyProvisioner) Provision(ctx context.Context,
for _, file := range files { for _, file := range files {
for _, np := range file.Policies { for _, np := range file.Policies {
err := c.notificationPolicyService.UpdatePolicyTree(ctx, np.OrgID, err := c.notificationPolicyService.UpdatePolicyTree(ctx, np.OrgID,
np.Policy, models.ProvenanceFile) np.Policy, models.ProvenanceFile, "")
if err != nil { if err != nil {
return fmt.Errorf("%s: %w", file.Filename, err) return fmt.Errorf("%s: %w", file.Filename, err)
} }