diff --git a/app/channel.go b/app/channel.go index b0c51bfe2f..01dd1366b4 100644 --- a/app/channel.go +++ b/app/channel.go @@ -1806,12 +1806,16 @@ func (a *App) MarkChannelAsUnreadFromPost(postID string, userID string) (*model. if err != nil { return channel, updateErr } + message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POST_UNREAD, channel.TeamId, channel.ChannelId, channel.UserId, nil) message.Add("msg_count", channel.MsgCount) message.Add("mention_count", channel.MentionCount) message.Add("last_viewed_at", channel.LastViewedAt) message.Add("post_id", postID) a.Publish(message) + + a.UpdateMobileAppBadge(userID) + return channel, nil } diff --git a/app/notification_push.go b/app/notification_push.go index 6eaed43f58..546ba2ff3a 100644 --- a/app/notification_push.go +++ b/app/notification_push.go @@ -20,6 +20,7 @@ type NotificationType string const NOTIFICATION_TYPE_CLEAR NotificationType = "clear" const NOTIFICATION_TYPE_MESSAGE NotificationType = "message" +const NOTIFICATION_TYPE_UPDATE_BADGE NotificationType = "update_badge" const PUSH_NOTIFICATION_HUB_WORKERS = 1000 const PUSH_NOTIFICATIONS_HUB_BUFFER_PER_WORKER = 50 @@ -53,16 +54,23 @@ func (hub *PushNotificationsHub) GetGoChannelFromUserId(userId string) chan Push func (a *App) sendPushNotificationSync(post *model.Post, user *model.User, channel *model.Channel, channelName string, senderName string, explicitMention bool, channelWideMention bool, replyToThreadType string) *model.AppError { - - sessions, err := a.getMobileAppSessions(user.Id) + msg, err := a.BuildPushNotificationMessage(post, user, channel, channelName, senderName, explicitMention, channelWideMention, replyToThreadType) if err != nil { return err } - msg := a.BuildPushNotificationMessage(post, user, channel, channelName, senderName, explicitMention, channelWideMention, replyToThreadType) + return a.sendPushNotificationToAllSessions(msg, user.Id, "") +} + +func (a *App) sendPushNotificationToAllSessions(msg *model.PushNotification, userId string, skipSessionId string) *model.AppError { + sessions, err := a.getMobileAppSessions(userId) + if err != nil { + return err + } for _, session := range sessions { - if session.IsExpired() { + // Don't send notifications to this session if it's expired or we want to skip it + if session.IsExpired() || (skipSessionId != "" && skipSessionId == session.Id) { continue } @@ -170,63 +178,22 @@ func (a *App) getPushNotificationMessage(postMessage string, explicitMention, ch return senderName + userLocale("api.post.send_notifications_and_forget.push_general_message") } -func (a *App) ClearPushNotificationSync(currentSessionId, userId, channelId string) { - sessions, err := a.getMobileAppSessions(userId) - if err != nil { - mlog.Error("error getting mobile app sessions", mlog.Err(err)) - return - } - - msg := model.PushNotification{ +func (a *App) ClearPushNotificationSync(currentSessionId, userId, channelId string) *model.AppError { + msg := &model.PushNotification{ Type: model.PUSH_TYPE_CLEAR, Version: model.PUSH_MESSAGE_V2, ChannelId: channelId, ContentAvailable: 1, } - if unreadCount, err := a.Srv.Store.User().GetUnreadCount(userId); err != nil { - msg.Badge = 0 - mlog.Error("We could not get the unread message count for", mlog.String("user_id", userId), mlog.Err(err)) - } else { - msg.Badge = int(unreadCount) + unreadCount, err := a.Srv.Store.User().GetUnreadCount(userId) + if err != nil { + return err } - for _, session := range sessions { - if currentSessionId != session.Id { - tmpMessage := model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) - tmpMessage.SetDeviceIdAndPlatform(session.DeviceId) - tmpMessage.AckId = model.NewId() + msg.Badge = int(unreadCount) - err := a.sendToPushProxy(*tmpMessage, session) - if err != nil { - a.NotificationsLog.Error("Notification error", - mlog.String("ackId", tmpMessage.AckId), - mlog.String("type", tmpMessage.Type), - mlog.String("userId", session.UserId), - mlog.String("postId", tmpMessage.PostId), - mlog.String("channelId", tmpMessage.ChannelId), - mlog.String("deviceId", tmpMessage.DeviceId), - mlog.String("status", err.Error()), - ) - - continue - } - - a.NotificationsLog.Info("Notification sent", - mlog.String("ackId", tmpMessage.AckId), - mlog.String("type", tmpMessage.Type), - mlog.String("userId", session.UserId), - mlog.String("postId", tmpMessage.PostId), - mlog.String("channelId", tmpMessage.ChannelId), - mlog.String("deviceId", tmpMessage.DeviceId), - mlog.String("status", model.PUSH_SEND_SUCCESS), - ) - - if a.Metrics != nil { - a.Metrics.IncrementPostSentPush() - } - } - } + return a.sendPushNotificationToAllSessions(msg, userId, currentSessionId) } func (a *App) ClearPushNotification(currentSessionId, userId, channelId string) { @@ -239,6 +206,32 @@ func (a *App) ClearPushNotification(currentSessionId, userId, channelId string) } } +func (a *App) UpdateMobileAppBadgeSync(userId string) *model.AppError { + msg := &model.PushNotification{ + Type: model.PUSH_TYPE_UPDATE_BADGE, + Version: model.PUSH_MESSAGE_V2, + Sound: "none", + ContentAvailable: 1, + } + + unreadCount, err := a.Srv.Store.User().GetUnreadCount(userId) + if err != nil { + return err + } + + msg.Badge = int(unreadCount) + + return a.sendPushNotificationToAllSessions(msg, userId, "") +} + +func (a *App) UpdateMobileAppBadge(userId string) { + channel := a.Srv.PushNotificationsHub.GetGoChannelFromUserId(userId) + channel <- PushNotification{ + notificationType: NOTIFICATION_TYPE_UPDATE_BADGE, + userId: userId, + } +} + func (a *App) CreatePushNotificationsHub() { hub := PushNotificationsHub{ Channels: []chan PushNotification{}, @@ -251,11 +244,13 @@ func (a *App) CreatePushNotificationsHub() { func (a *App) pushNotificationWorker(notifications chan PushNotification) { for notification := range notifications { + var err *model.AppError + switch notification.notificationType { case NOTIFICATION_TYPE_CLEAR: - a.ClearPushNotificationSync(notification.currentSessionId, notification.userId, notification.channelId) + err = a.ClearPushNotificationSync(notification.currentSessionId, notification.userId, notification.channelId) case NOTIFICATION_TYPE_MESSAGE: - a.sendPushNotificationSync( + err = a.sendPushNotificationSync( notification.post, notification.user, notification.channel, @@ -265,9 +260,15 @@ func (a *App) pushNotificationWorker(notifications chan PushNotification) { notification.channelWideMention, notification.replyToThreadType, ) + case NOTIFICATION_TYPE_UPDATE_BADGE: + err = a.UpdateMobileAppBadgeSync(notification.userId) default: mlog.Error("Invalid notification type", mlog.String("notification_type", string(notification.notificationType))) } + + if err != nil { + mlog.Error("Unable to send push notification", mlog.String("notification_type", string(notification.notificationType)), mlog.Err(err)) + } } } @@ -429,9 +430,9 @@ func DoesStatusAllowPushNotification(userNotifyProps model.StringMap, status *mo } func (a *App) BuildPushNotificationMessage(post *model.Post, user *model.User, channel *model.Channel, channelName string, senderName string, - explicitMention bool, channelWideMention bool, replyToThreadType string) model.PushNotification { + explicitMention bool, channelWideMention bool, replyToThreadType string) (*model.PushNotification, *model.AppError) { - msg := model.PushNotification{ + msg := &model.PushNotification{ Category: model.CATEGORY_CAN_REPLY, Version: model.PUSH_MESSAGE_V2, Type: model.PUSH_TYPE_MESSAGE, @@ -443,19 +444,19 @@ func (a *App) BuildPushNotificationMessage(post *model.Post, user *model.User, c } if user.NotifyProps["push"] == "all" { - if unreadCount, err := a.Srv.Store.User().GetAnyUnreadPostCountForChannel(user.Id, channel.Id); err != nil { - msg.Badge = 1 - mlog.Error("We could not get the unread message count for the user", mlog.String("user_id", user.Id), mlog.Err(err)) - } else { - msg.Badge = int(unreadCount) + unreadCount, err := a.Srv.Store.User().GetAnyUnreadPostCountForChannel(user.Id, channel.Id) + if err != nil { + return nil, err } + + msg.Badge = int(unreadCount) } else { - if unreadCount, err := a.Srv.Store.User().GetUnreadCount(user.Id); err != nil { - msg.Badge = 1 - mlog.Error("We could not get the unread message count for the user", mlog.String("user_id", user.Id), mlog.Err(err)) - } else { - msg.Badge = int(unreadCount) + unreadCount, err := a.Srv.Store.User().GetUnreadCount(user.Id) + if err != nil { + return nil, err } + + msg.Badge = int(unreadCount) } cfg := a.Config() @@ -483,5 +484,5 @@ func (a *App) BuildPushNotificationMessage(post *model.Post, user *model.User, c msg.Message = a.getPushNotificationMessage(post.Message, explicitMention, channelWideMention, hasFiles, msg.SenderName, channelName, channel.Type, replyToThreadType, userLocale) - return msg + return msg, nil } diff --git a/app/notification_push_test.go b/app/notification_push_test.go index a7f952df33..f34ff0c150 100644 --- a/app/notification_push_test.go +++ b/app/notification_push_test.go @@ -10,6 +10,7 @@ import ( "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/utils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDoesNotifyPropsAllowPushNotification(t *testing.T) { @@ -942,7 +943,8 @@ func TestBuildPushNotificationMessage(t *testing.T) { } { t.Run(name, func(t *testing.T) { receiver.NotifyProps["push"] = tc.pushNotifyProps - msg := th.App.BuildPushNotificationMessage(post, receiver, channel, channel.Name, sender.Username, tc.explicitMention, tc.channelWideMention, tc.replyToThreadType) + msg, err := th.App.BuildPushNotificationMessage(post, receiver, channel, channel.Name, sender.Username, tc.explicitMention, tc.channelWideMention, tc.replyToThreadType) + require.Nil(t, err) assert.Equal(t, tc.expectedBadge, msg.Badge) }) } diff --git a/go.mod b/go.mod index 66f2b180d6..f01c778fbf 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,8 @@ require ( github.com/miekg/dns v1.1.19 // indirect github.com/minio/minio-go/v6 v6.0.38 github.com/mitchellh/go-testing-interface v1.0.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect github.com/muesli/smartcrop v0.3.0 // indirect github.com/olekukonko/tablewriter v0.0.1 // indirect github.com/onsi/ginkgo v1.8.0 // indirect diff --git a/model/push_notification.go b/model/push_notification.go index 03b49bc897..e846f565d7 100644 --- a/model/push_notification.go +++ b/model/push_notification.go @@ -15,9 +15,12 @@ const ( PUSH_NOTIFY_APPLE_REACT_NATIVE = "apple_rn" PUSH_NOTIFY_ANDROID_REACT_NATIVE = "android_rn" - PUSH_TYPE_MESSAGE = "message" - PUSH_TYPE_CLEAR = "clear" - PUSH_MESSAGE_V2 = "v2" + PUSH_TYPE_MESSAGE = "message" + PUSH_TYPE_CLEAR = "clear" + PUSH_TYPE_UPDATE_BADGE = "update_badge" + PUSH_MESSAGE_V2 = "v2" + + PUSH_SOUND_NONE = "none" // The category is set to handle a set of interactive Actions // with the push notifications diff --git a/store/sqlstore/user_store.go b/store/sqlstore/user_store.go index 5febcc88e0..4c3acf993f 100644 --- a/store/sqlstore/user_store.go +++ b/store/sqlstore/user_store.go @@ -1129,7 +1129,7 @@ func (us SqlUserStore) AnalyticsActiveCount(timePeriod int64, options model.User return v, nil } -func (us SqlUserStore) GetUnreadCount(userId string) (int64, error) { +func (us SqlUserStore) GetUnreadCount(userId string) (int64, *model.AppError) { query := ` SELECT SUM(CASE WHEN c.Type = 'D' THEN (c.TotalMsgCount - cm.MsgCount) ELSE cm.MentionCount END) FROM Channels c diff --git a/store/store.go b/store/store.go index 0e76a10e21..a1385967b2 100644 --- a/store/store.go +++ b/store/store.go @@ -277,7 +277,7 @@ type UserStore interface { GetSystemAdminProfiles() (map[string]*model.User, *model.AppError) PermanentDelete(userId string) *model.AppError AnalyticsActiveCount(time int64, options model.UserCountOptions) (int64, *model.AppError) - GetUnreadCount(userId string) (int64, error) + GetUnreadCount(userId string) (int64, *model.AppError) GetUnreadCountForChannel(userId string, channelId string) (int64, *model.AppError) GetAnyUnreadPostCountForChannel(userId string, channelId string) (int64, *model.AppError) GetRecentlyActiveUsersForTeam(teamId string, offset, limit int, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError) diff --git a/store/storetest/mocks/UserStore.go b/store/storetest/mocks/UserStore.go index 43931cd758..c0bb32b9a1 100644 --- a/store/storetest/mocks/UserStore.go +++ b/store/storetest/mocks/UserStore.go @@ -810,7 +810,7 @@ func (_m *UserStore) GetTeamGroupUsers(teamID string) ([]*model.User, *model.App } // GetUnreadCount provides a mock function with given fields: userId -func (_m *UserStore) GetUnreadCount(userId string) (int64, error) { +func (_m *UserStore) GetUnreadCount(userId string) (int64, *model.AppError) { ret := _m.Called(userId) var r0 int64 @@ -820,11 +820,13 @@ func (_m *UserStore) GetUnreadCount(userId string) (int64, error) { r0 = ret.Get(0).(int64) } - var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string) *model.AppError); ok { r1 = rf(userId) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } } return r0, r1 diff --git a/store/timer_layer.go b/store/timer_layer.go index 3632de3d81..b4983eab0f 100644 --- a/store/timer_layer.go +++ b/store/timer_layer.go @@ -6504,7 +6504,7 @@ func (s *TimerLayerUserStore) GetTeamGroupUsers(teamID string) ([]*model.User, * return resultVar0, resultVar1 } -func (s *TimerLayerUserStore) GetUnreadCount(userId string) (int64, error) { +func (s *TimerLayerUserStore) GetUnreadCount(userId string) (int64, *model.AppError) { start := timemodule.Now() resultVar0, resultVar1 := s.UserStore.GetUnreadCount(userId)