From b8d1474609206a80325536164420a93454ed18a5 Mon Sep 17 00:00:00 2001 From: Alexander Weaver Date: Tue, 6 Sep 2022 14:51:54 -0500 Subject: [PATCH] Fix incorrect propagation of org ID in rule endpionts (#54603) --- pkg/services/ngalert/api/api_provisioning.go | 14 +++--- .../ngalert/api/api_provisioning_test.go | 46 +++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/pkg/services/ngalert/api/api_provisioning.go b/pkg/services/ngalert/api/api_provisioning.go index c523084d822..daa16e7e936 100644 --- a/pkg/services/ngalert/api/api_provisioning.go +++ b/pkg/services/ngalert/api/api_provisioning.go @@ -255,6 +255,7 @@ func (srv *ProvisioningSrv) RouteRouteGetAlertRule(c *models.ReqContext, UID str func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definitions.ProvisionedAlertRule) response.Response { upstreamModel, err := ar.UpstreamModel() + upstreamModel.OrgID = c.OrgID if err != nil { ErrResp(http.StatusBadRequest, err, "") } @@ -271,10 +272,9 @@ func (srv *ProvisioningSrv) RoutePostAlertRule(c *models.ReqContext, ar definiti } return ErrResp(http.StatusInternalServerError, err, "") } - ar.ID = createdAlertRule.ID - ar.UID = createdAlertRule.UID - ar.Updated = createdAlertRule.Updated - return response.JSON(http.StatusCreated, ar) + + resp := definitions.NewAlertRule(createdAlertRule, alerting_models.ProvenanceAPI) + return response.JSON(http.StatusCreated, resp) } func (srv *ProvisioningSrv) RoutePutAlertRule(c *models.ReqContext, ar definitions.ProvisionedAlertRule, UID string) response.Response { @@ -282,6 +282,7 @@ func (srv *ProvisioningSrv) RoutePutAlertRule(c *models.ReqContext, ar definitio if err != nil { ErrResp(http.StatusBadRequest, err, "") } + updated.OrgID = c.OrgID updated.UID = UID updatedAlertRule, err := srv.alertRules.UpdateAlertRule(c.Req.Context(), updated, alerting_models.ProvenanceAPI) if errors.Is(err, alerting_models.ErrAlertRuleNotFound) { @@ -296,8 +297,9 @@ func (srv *ProvisioningSrv) RoutePutAlertRule(c *models.ReqContext, ar definitio } return ErrResp(http.StatusInternalServerError, err, "") } - ar.Updated = updatedAlertRule.Updated - return response.JSON(http.StatusOK, ar) + + resp := definitions.NewAlertRule(updatedAlertRule, alerting_models.ProvenanceAPI) + return response.JSON(http.StatusOK, resp) } func (srv *ProvisioningSrv) RouteDeleteAlertRule(c *models.ReqContext, UID string) response.Response { diff --git a/pkg/services/ngalert/api/api_provisioning_test.go b/pkg/services/ngalert/api/api_provisioning_test.go index 9143151dbe1..87458dce7ea 100644 --- a/pkg/services/ngalert/api/api_provisioning_test.go +++ b/pkg/services/ngalert/api/api_provisioning_test.go @@ -257,6 +257,38 @@ func TestProvisioningApi(t *testing.T) { }) }) + t.Run("exist in non-default orgs", func(t *testing.T) { + t.Run("POST sets expected fields", func(t *testing.T) { + sut := createProvisioningSrvSut(t) + rc := createTestRequestCtx() + rc.OrgID = 3 + rule := createTestAlertRule("rule", 1) + + response := sut.RoutePostAlertRule(&rc, rule) + + require.Equal(t, 201, response.Status()) + created := deserializeRule(t, response.Body()) + require.Equal(t, int64(3), created.OrgID) + }) + + t.Run("PUT sets expected fields", func(t *testing.T) { + sut := createProvisioningSrvSut(t) + uid := t.Name() + rule := createTestAlertRule("rule", 1) + rule.UID = uid + insertRuleInOrg(t, sut, rule, 3) + rc := createTestRequestCtx() + rc.OrgID = 3 + rule.OrgID = 1 // Set the org back to something wrong, we should still prefer the value from the req context. + + response := sut.RoutePutAlertRule(&rc, rule, rule.UID) + + require.Equal(t, 200, response.Status()) + created := deserializeRule(t, response.Body()) + require.Equal(t, int64(3), created.OrgID) + }) + }) + t.Run("are missing, PUT returns 404", func(t *testing.T) { sut := createProvisioningSrvSut(t) rc := createTestRequestCtx() @@ -543,13 +575,27 @@ func createTestAlertRule(title string, orgID int64) definitions.ProvisionedAlert } func insertRule(t *testing.T, srv ProvisioningSrv, rule definitions.ProvisionedAlertRule) { + insertRuleInOrg(t, srv, rule, 1) +} + +func insertRuleInOrg(t *testing.T, srv ProvisioningSrv, rule definitions.ProvisionedAlertRule, orgID int64) { t.Helper() rc := createTestRequestCtx() + rc.OrgID = orgID resp := srv.RoutePostAlertRule(&rc, rule) require.Equal(t, 201, resp.Status()) } +func deserializeRule(t *testing.T, data []byte) definitions.ProvisionedAlertRule { + t.Helper() + + var rule definitions.ProvisionedAlertRule + err := json.Unmarshal(data, &rule) + require.NoError(t, err) + return rule +} + var testConfig = ` { "template_files": {