diff --git a/api4/system.go b/api4/system.go index f28311d488..e2e5390e47 100644 --- a/api4/system.go +++ b/api4/system.go @@ -411,14 +411,23 @@ func getRedirectLocation(c *Context, w http.ResponseWriter, r *http.Request) { } func pushNotificationAck(c *Context, w http.ResponseWriter, r *http.Request) { - ack := model.PushNotificationAckFromJson(r.Body) + ack, err := model.PushNotificationAckFromJson(r.Body) + if err != nil { + c.Err = model.NewAppError("pushNotificationAck", + "api.push_notifications_ack.message.parse.app_error", + nil, + err.Error(), + http.StatusBadRequest, + ) + return + } if !*c.App.Config().EmailSettings.SendPushNotifications { c.Err = model.NewAppError("pushNotificationAck", "api.push_notification.disabled.app_error", nil, "", http.StatusNotImplemented) return } - err := c.App.SendAckToPushProxy(ack) + err = c.App.SendAckToPushProxy(ack) if ack.IsIdLoaded { if err != nil { // Log the error only, then continue to fetch notification message diff --git a/api4/system_test.go b/api4/system_test.go index a38d815fcb..8936f89959 100644 --- a/api4/system_test.go +++ b/api4/system_test.go @@ -627,3 +627,20 @@ func TestServerBusy503(t *testing.T) { CheckNoError(t, resp) }) } + +func TestPushNotificationAck(t *testing.T) { + th := Setup().InitBasic() + api := Init(th.Server, th.Server.AppOptions, th.Server.Router) + session, _ := th.App.GetSession(th.Client.AuthToken) + defer th.TearDown() + t.Run("should return error when the ack body is not passed", func(t *testing.T) { + handler := api.ApiHandler(pushNotificationAck) + resp := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/v4/notifications/ack", nil) + req.Header.Set(model.HEADER_AUTH, "Bearer "+session.Token) + + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusBadRequest, resp.Code) + assert.NotNil(t, resp.Body) + }) +} diff --git a/app/notification_push.go b/app/notification_push.go index 7c51ec96b3..71a6581722 100644 --- a/app/notification_push.go +++ b/app/notification_push.go @@ -78,13 +78,35 @@ func (a *App) sendPushNotificationToAllSessions(msg *model.PushNotification, use return err } + if msg == nil { + return model.NewAppError( + "pushNotification", + "api.push_notifications.message.parse.app_error", + nil, + "", + http.StatusBadRequest, + ) + } + + notification, parseError := model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) + if parseError != nil { + return model.NewAppError( + "pushNotification", + "api.push_notifications.message.parse.app_error", + nil, + parseError.Error(), + http.StatusInternalServerError, + ) + } + for _, session := range sessions { // Don't send notifications to this session if it's expired or we want to skip it if session.IsExpired() || (skipSessionId != "" && skipSessionId == session.Id) { continue } - tmpMessage := model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) + // We made a copy to avoid decoding and parsing all the time + tmpMessage := notification tmpMessage.SetDeviceIdAndPlatform(session.DeviceId) tmpMessage.AckId = model.NewId() diff --git a/app/notification_push_test.go b/app/notification_push_test.go index 0830ae3634..b400f466ad 100644 --- a/app/notification_push_test.go +++ b/app/notification_push_test.go @@ -955,3 +955,22 @@ func TestBuildPushNotificationMessageMentions(t *testing.T) { }) } } + +func TestSendPushNotifications(t *testing.T) { + th := Setup(t).InitBasic() + th.App.CreateSession(&model.Session{ + UserId: th.BasicUser.Id, + DeviceId: "test", + ExpiresAt: model.GetMillis() + 100000, + }) + defer th.TearDown() + + t.Run("should return error if data is not valid or nil", func(t *testing.T) { + err := th.App.sendPushNotificationToAllSessions(nil, th.BasicUser.Id, "") + assert.NotNil(t, err) + assert.Equal(t, "pushNotification: An error occurred building the push notification message, ", err.Error()) + // Errors derived of using an empty object are handled internally through the notifications log + err = th.App.sendPushNotificationToAllSessions(&model.PushNotification{}, th.BasicUser.Id, "") + assert.Nil(t, err) + }) +} diff --git a/i18n/en.json b/i18n/en.json index 0c7cba4b1f..2405e793b2 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -1762,10 +1762,18 @@ "id": "api.push_notification.id_loaded.fetch.app_error", "translation": "An error occurred fetching the ID-loaded push notification" }, + { + "id": "api.push_notifications.message.parse.app_error", + "translation": "An error occurred building the push notification message" + }, { "id": "api.push_notifications_ack.forward.app_error", "translation": "An error occurred sending the receipt delivery to the push notification service" }, + { + "id": "api.push_notifications_ack.message.parse.app_error", + "translation": "An error occurred building the push notification ack message" + }, { "id": "api.reaction.delete.archived_channel.app_error", "translation": "You cannot remove a reaction in an archived channel." diff --git a/model/push_notification.go b/model/push_notification.go index 78711abbde..a601f85a02 100644 --- a/model/push_notification.go +++ b/model/push_notification.go @@ -5,6 +5,7 @@ package model import ( "encoding/json" + "errors" "io" "strings" ) @@ -83,16 +84,26 @@ func (me *PushNotification) SetDeviceIdAndPlatform(deviceId string) { } } -func PushNotificationFromJson(data io.Reader) *PushNotification { +func PushNotificationFromJson(data io.Reader) (*PushNotification, error) { + if data == nil { + return nil, errors.New("push notification data can't be nil") + } var me *PushNotification - json.NewDecoder(data).Decode(&me) - return me + if err := json.NewDecoder(data).Decode(&me); err != nil { + return nil, err + } + return me, nil } -func PushNotificationAckFromJson(data io.Reader) *PushNotificationAck { +func PushNotificationAckFromJson(data io.Reader) (*PushNotificationAck, error) { + if data == nil { + return nil, errors.New("push notification data can't be nil") + } var ack *PushNotificationAck - json.NewDecoder(data).Decode(&ack) - return ack + if err := json.NewDecoder(data).Decode(&ack); err != nil { + return nil, err + } + return ack, nil } func (ack *PushNotificationAck) ToJson() string { diff --git a/model/push_notification_test.go b/model/push_notification_test.go index 8fe0980cff..82973b4f9d 100644 --- a/model/push_notification_test.go +++ b/model/push_notification_test.go @@ -11,11 +11,49 @@ import ( ) func TestPushNotification(t *testing.T) { - msg := PushNotification{Platform: "test"} - json := msg.ToJson() - result := PushNotificationFromJson(strings.NewReader(json)) + t.Run("should build a push notification from JSON", func(t *testing.T) { + msg := PushNotification{Platform: "test"} + json := msg.ToJson() + result, err := PushNotificationFromJson(strings.NewReader(json)) - require.Equal(t, msg.Platform, result.Platform, "Ids do not match") + require.Nil(t, err) + require.Equal(t, msg.Platform, result.Platform, "ids do not match") + }) + + t.Run("should throw an error when the message is nil", func(t *testing.T) { + _, err := PushNotificationFromJson(nil) + require.NotNil(t, err) + require.Equal(t, "push notification data can't be nil", err.Error()) + }) + + t.Run("should throw an error when the message parsing fails", func(t *testing.T) { + _, err := PushNotificationFromJson(strings.NewReader("")) + require.NotNil(t, err) + require.Equal(t, "EOF", err.Error()) + }) +} + +func TestPushNotificationAck(t *testing.T) { + t.Run("should build a push notification ack from JSON", func(t *testing.T) { + msg := PushNotificationAck{ClientPlatform: "test"} + json := msg.ToJson() + result, err := PushNotificationAckFromJson(strings.NewReader(json)) + + require.Nil(t, err) + require.Equal(t, msg.ClientPlatform, result.ClientPlatform, "ids do not match") + }) + + t.Run("should throw an error when the message is nil", func(t *testing.T) { + _, err := PushNotificationAckFromJson(nil) + require.NotNil(t, err) + require.Equal(t, "push notification data can't be nil", err.Error()) + }) + + t.Run("should throw an error when the message parsing fails", func(t *testing.T) { + _, err := PushNotificationAckFromJson(strings.NewReader("")) + require.NotNil(t, err) + require.Equal(t, "EOF", err.Error()) + }) } func TestPushNotificationDeviceId(t *testing.T) {