Extract Route validation from serialization methods so it can be re-used (#47649)

* Extract validation and reject invalid policies

* Validation in dedicated file

* Tests for validation

* Extract root route validation

* Update call and drop TODO

* empty commit to kick actions

* Normalization should be idempotent

* Cleaner representation of validation errors, chain errors properly

* Make internal validate unexported

* Fix missed rename

* Genericize error message

* Improve method names

* Rebase, fix

* Update asserts
This commit is contained in:
Alexander Weaver 2022-04-27 15:15:41 -05:00 committed by GitHub
parent 900d9bf9a1
commit 60ec10566f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 380 additions and 50 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/models"
apimodels "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
alerting_models "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/provisioning"
"github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/web"
@ -51,6 +52,9 @@ func (srv *ProvisioningSrv) RoutePostPolicyTree(c *models.ReqContext, tree apimo
if errors.Is(err, store.ErrNoAlertmanagerConfiguration) {
return ErrResp(http.StatusNotFound, err, "")
}
if errors.Is(err, provisioning.ErrValidation) {
return ErrResp(http.StatusBadRequest, err, "")
}
if err != nil {
return ErrResp(http.StatusInternalServerError, err, "")
}

View File

@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/models"
apimodels "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
domain "github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/grafana/grafana/pkg/services/ngalert/provisioning"
"github.com/grafana/grafana/pkg/services/ngalert/store"
"github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/require"
@ -35,7 +36,20 @@ func TestProvisioningApi(t *testing.T) {
require.Equal(t, 202, response.Status())
})
// TODO: we have not lifted out validation yet. Test that we are returning errors properly once validation has been lifted.
t.Run("when new policy tree is invalid", func(t *testing.T) {
t.Run("POST policies returns 400", func(t *testing.T) {
sut := createProvisioningSrvSut()
sut.policies = &fakeRejectingNotificationPolicyService{}
rc := createTestRequestCtx()
tree := apimodels.Route{}
response := sut.RoutePostPolicyTree(&rc, tree)
require.Equal(t, 400, response.Status())
expBody := `{"error":"invalid object specification: invalid policy tree","message":"invalid object specification: invalid policy tree"}`
require.Equal(t, expBody, string(response.Body()))
})
})
t.Run("when org has no AM config", func(t *testing.T) {
t.Run("GET policies returns 404", func(t *testing.T) {
@ -146,3 +160,13 @@ func (f *fakeFailingNotificationPolicyService) GetPolicyTree(ctx context.Context
func (f *fakeFailingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree apimodels.Route, p domain.Provenance) error {
return fmt.Errorf("something went wrong")
}
type fakeRejectingNotificationPolicyService struct{}
func (f *fakeRejectingNotificationPolicyService) GetPolicyTree(ctx context.Context, orgID int64) (apimodels.Route, error) {
return apimodels.Route{}, nil
}
func (f *fakeRejectingNotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree apimodels.Route, p domain.Provenance) error {
return fmt.Errorf("%w: invalid policy tree", provisioning.ErrValidation)
}

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
amv2 "github.com/prometheus/alertmanager/api/v2/models"
"github.com/prometheus/alertmanager/config"
"github.com/prometheus/alertmanager/pkg/labels"
@ -724,35 +723,7 @@ func (r *Route) UnmarshalYAML(unmarshal func(interface{}) error) error {
return err
}
for _, l := range r.GroupByStr {
if l == "..." {
r.GroupByAll = true
} else {
r.GroupBy = append(r.GroupBy, model.LabelName(l))
}
}
if len(r.GroupBy) > 0 && r.GroupByAll {
return fmt.Errorf("cannot have wildcard group_by (`...`) and other other labels at the same time")
}
groupBy := map[model.LabelName]struct{}{}
for _, ln := range r.GroupBy {
if _, ok := groupBy[ln]; ok {
return fmt.Errorf("duplicated label %q in group_by", ln)
}
groupBy[ln] = struct{}{}
}
if r.GroupInterval != nil && time.Duration(*r.GroupInterval) == time.Duration(0) {
return fmt.Errorf("group_interval cannot be zero")
}
if r.RepeatInterval != nil && time.Duration(*r.RepeatInterval) == time.Duration(0) {
return fmt.Errorf("repeat_interval cannot be zero")
}
return nil
return r.validateChild()
}
// Return an alertmanager route from a Grafana route. The ObjectMatchers are converted to Matchers.
@ -837,25 +808,9 @@ func (c *Config) UnmarshalJSON(b []byte) error {
return fmt.Errorf("no routes provided")
}
// Route is a recursive structure that includes validation in the yaml unmarshaler.
// Therefore, we'll redirect json -> yaml to utilize these.
b, err := yaml.Marshal(c.Route)
err := c.Route.Validate()
if err != nil {
return errors.Wrap(err, "marshaling route to yaml for validation")
}
err = yaml.Unmarshal(b, c.Route)
if err != nil {
return errors.Wrap(err, "unmarshaling route for validations")
}
if len(c.Route.Receiver) == 0 {
return fmt.Errorf("root route must specify a default receiver")
}
if len(c.Route.Match) > 0 || len(c.Route.MatchRE) > 0 {
return fmt.Errorf("root route must not have any matchers")
}
if len(c.Route.MuteTimeIntervals) > 0 {
return fmt.Errorf("root route must not have any mute time intervals")
return err
}
for _, r := range c.InhibitRules {

View File

@ -0,0 +1,67 @@
package definitions
import (
"fmt"
"time"
"github.com/prometheus/common/model"
)
// Validate normalizes a possibly nested Route r, and returns errors if r is invalid.
func (r *Route) validateChild() error {
r.GroupBy = nil
r.GroupByAll = false
for _, l := range r.GroupByStr {
if l == "..." {
r.GroupByAll = true
} else {
r.GroupBy = append(r.GroupBy, model.LabelName(l))
}
}
if len(r.GroupBy) > 0 && r.GroupByAll {
return fmt.Errorf("cannot have wildcard group_by (`...`) and other other labels at the same time")
}
groupBy := map[model.LabelName]struct{}{}
for _, ln := range r.GroupBy {
if _, ok := groupBy[ln]; ok {
return fmt.Errorf("duplicated label %q in group_by, %s %s", ln, r.Receiver, r.GroupBy)
}
groupBy[ln] = struct{}{}
}
if r.GroupInterval != nil && time.Duration(*r.GroupInterval) == time.Duration(0) {
return fmt.Errorf("group_interval cannot be zero")
}
if r.RepeatInterval != nil && time.Duration(*r.RepeatInterval) == time.Duration(0) {
return fmt.Errorf("repeat_interval cannot be zero")
}
// Routes are a self-referential structure.
if r.Routes != nil {
for _, child := range r.Routes {
err := child.validateChild()
if err != nil {
return err
}
}
}
return nil
}
// Validate normalizes a Route r, and returns errors if r is an invalid root route. Root routes must satisfy a few additional conditions.
func (r *Route) Validate() error {
if len(r.Receiver) == 0 {
return fmt.Errorf("root route must specify a default receiver")
}
if len(r.Match) > 0 || len(r.MatchRE) > 0 {
return fmt.Errorf("root route must not have any matchers")
}
if len(r.MuteTimeIntervals) > 0 {
return fmt.Errorf("root route must not have any mute time intervals")
}
return r.validateChild()
}

View File

@ -0,0 +1,257 @@
package definitions
import (
"testing"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/require"
)
func TestValidateRoutes(t *testing.T) {
zero := model.Duration(0)
type testCase struct {
desc string
route Route
expMsg string
}
t.Run("valid route", func(t *testing.T) {
cases := []testCase{
{
desc: "empty",
route: Route{},
},
{
desc: "simple",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
},
},
{
desc: "nested",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
Routes: []*Route{
{
Receiver: "bar",
},
},
},
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := c.route.validateChild()
require.NoError(t, err)
})
}
})
t.Run("invalid route", func(t *testing.T) {
cases := []testCase{
{
desc: "zero group interval",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
GroupInterval: &zero,
},
expMsg: "group_interval cannot be zero",
},
{
desc: "zero repeat interval",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
RepeatInterval: &zero,
},
expMsg: "repeat_interval cannot be zero",
},
{
desc: "duplicated label",
route: Route{
Receiver: "foo",
GroupByStr: []string{
"abc",
"abc",
},
},
expMsg: "duplicated label",
},
{
desc: "wildcard and non-wildcard label simultaneously",
route: Route{
Receiver: "foo",
GroupByStr: []string{
"...",
"abc",
},
},
expMsg: "cannot have wildcard",
},
{
desc: "valid with nested invalid",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
Routes: []*Route{
{
GroupByStr: []string{
"abc",
"abc",
},
},
},
},
expMsg: "duplicated label",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := c.route.validateChild()
require.Error(t, err)
require.Contains(t, err.Error(), c.expMsg)
})
}
})
t.Run("route validator normalizes group_by", func(t *testing.T) {
t.Run("when grouping normally", func(t *testing.T) {
route := Route{
Receiver: "foo",
GroupByStr: []string{"abc", "def"},
}
_ = route.validateChild()
require.False(t, route.GroupByAll)
require.Equal(t, []model.LabelName{"abc", "def"}, route.GroupBy)
})
t.Run("when grouping by wildcard, nil", func(t *testing.T) {
route := Route{
Receiver: "foo",
GroupByStr: []string{"..."},
}
_ = route.validateChild()
require.True(t, route.GroupByAll)
require.Nil(t, route.GroupBy)
})
t.Run("idempotently", func(t *testing.T) {
route := Route{
Receiver: "foo",
GroupByStr: []string{"abc", "def"},
}
err := route.validateChild()
require.NoError(t, err)
err = route.validateChild()
require.NoError(t, err)
require.False(t, route.GroupByAll)
require.Equal(t, []model.LabelName{"abc", "def"}, route.GroupBy)
})
})
t.Run("valid root route", func(t *testing.T) {
cases := []testCase{
{
desc: "simple",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
},
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := c.route.Validate()
require.NoError(t, err)
})
}
})
t.Run("invalid root route", func(t *testing.T) {
cases := []testCase{
{
desc: "no receiver",
route: Route{
GroupByStr: []string{"..."},
},
expMsg: "must specify a default receiver",
},
{
desc: "exact matchers present",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
Match: map[string]string{
"abc": "def",
},
},
expMsg: "must not have any matchers",
},
{
desc: "regex matchers present",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
Match: map[string]string{
"abc": "def",
},
},
expMsg: "must not have any matchers",
},
{
desc: "mute time intervals present",
route: Route{
Receiver: "foo",
GroupByStr: []string{"..."},
MuteTimeIntervals: []string{"10"},
},
expMsg: "must not have any mute time intervals",
},
{
desc: "validation error that is not specific to root",
route: Route{
Receiver: "foo",
GroupByStr: []string{"abc", "abc"},
},
expMsg: "duplicated label",
},
{
desc: "nested validation error that is not specific to root",
route: Route{
Receiver: "foo",
Routes: []*Route{
{
GroupByStr: []string{"abc", "abc"},
},
},
},
expMsg: "duplicated label",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := c.route.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), c.expMsg)
})
}
})
}

View File

@ -59,10 +59,15 @@ func (nps *NotificationPolicyService) GetPolicyTree(ctx context.Context, orgID i
}
func (nps *NotificationPolicyService) UpdatePolicyTree(ctx context.Context, orgID int64, tree definitions.Route, p models.Provenance) error {
err := tree.Validate()
if err != nil {
return fmt.Errorf("%w: %s", ErrValidation, err.Error())
}
q := models.GetLatestAlertmanagerConfigurationQuery{
OrgID: orgID,
}
err := nps.amStore.GetLatestAlertmanagerConfiguration(ctx, &q)
err = nps.amStore.GetLatestAlertmanagerConfiguration(ctx, &q)
if err != nil {
return err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions"
"github.com/grafana/grafana/pkg/services/ngalert/models"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/require"
)
@ -70,6 +71,18 @@ func TestNotificationPolicyService(t *testing.T) {
intercepted := fake.lastSaveCommand
require.Equal(t, expectedConcurrencyToken, intercepted.FetchedConfigurationHash)
})
t.Run("updating invalid route returns ValidationError", func(t *testing.T) {
sut := createNotificationPolicyServiceSut()
invalid := createTestRoutingTree()
repeat := model.Duration(0)
invalid.RepeatInterval = &repeat
err := sut.UpdatePolicyTree(context.Background(), 1, invalid, models.ProvenanceNone)
require.Error(t, err)
require.ErrorIs(t, err, ErrValidation)
})
}
func createNotificationPolicyServiceSut() *NotificationPolicyService {

View File

@ -0,0 +1,5 @@
package provisioning
import "fmt"
var ErrValidation = fmt.Errorf("invalid object specification")