diff --git a/server/channels/app/notification.go b/server/channels/app/notification.go index 8696577104..2a0cc3719d 100644 --- a/server/channels/app/notification.go +++ b/server/channels/app/notification.go @@ -517,12 +517,12 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea } } - if len(mentionedUsersList) != 0 { - message.Add("mentions", model.ArrayToJSON(mentionedUsersList)) + if len(mentionedUsersList) > 0 { + useAddMentionsHook(message, mentionedUsersList) } - if len(notificationsForCRT.Desktop) != 0 { - message.Add("followers", model.ArrayToJSON(notificationsForCRT.Desktop)) + if len(notificationsForCRT.Desktop) > 0 { + useAddFollowersHook(message, notificationsForCRT.Desktop) } published, err := a.publishWebsocketEventForPermalinkPost(c, post, message) diff --git a/server/channels/app/notification_test.go b/server/channels/app/notification_test.go index 8b0fc1ec7b..cad0e63021 100644 --- a/server/channels/app/notification_test.go +++ b/server/channels/app/notification_test.go @@ -4,10 +4,14 @@ package app import ( + "encoding/json" "fmt" + "net/http" + "net/http/httptest" "testing" "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -15,6 +19,7 @@ import ( "github.com/mattermost/mattermost/server/public/shared/i18n" pUtils "github.com/mattermost/mattermost/server/public/utils" + "github.com/mattermost/mattermost/server/v8/channels/app/platform" "github.com/mattermost/mattermost/server/v8/channels/store" ) @@ -212,6 +217,314 @@ func TestSendNotifications(t *testing.T) { }) } +func TestSendNotifications_MentionsFollowers(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + th.AddUserToChannel(th.BasicUser2, th.BasicChannel) + + sender := th.CreateUser() + + th.LinkUserToTeam(sender, th.BasicTeam) + member := th.AddUserToChannel(sender, th.BasicChannel) + + t.Run("should inform each user if they were mentioned by a post", func(t *testing.T) { + messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "") + defer closeWS1() + + messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "") + defer closeWS2() + + // First post mentioning the whole channel + post := &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + Message: "@channel", + } + _, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false) + require.NoError(t, err) + + received1 := <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"]) + + received2 := <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"]) + + // Second post mentioning both users individually + post = &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + Message: fmt.Sprintf("@%s @%s", th.BasicUser.Username, th.BasicUser2.Username), + } + _, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false) + require.NoError(t, err) + + received1 = <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"]) + + received2 = <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"]) + + // Third post mentioning a single user + post = &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + Message: "@" + th.BasicUser.Username, + } + _, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false) + require.NoError(t, err) + + received1 = <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"]) + + received2 = <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assert.Nil(t, received2.GetData()["mentions"]) + }) + + t.Run("should inform each user in a group if they were mentioned by a post", func(t *testing.T) { + // Make the sender a channel_admin because that's needed for group mentions + originalRoles := member.Roles + member.Roles = "channel_user channel_admin" + _, appErr := th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, member.Roles) + require.Nil(t, appErr) + + defer func() { + th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, originalRoles) + }() + + th.App.Srv().SetLicense(getLicWithSkuShortName(model.LicenseShortSkuEnterprise)) + + // Make a group and add users + group := th.CreateGroup() + group.AllowReference = true + group, updateErr := th.App.UpdateGroup(group) + require.Nil(t, updateErr) + + _, upsertErr := th.App.UpsertGroupMember(group.Id, th.BasicUser.Id) + require.Nil(t, upsertErr) + _, upsertErr = th.App.UpsertGroupMember(group.Id, th.BasicUser2.Id) + require.Nil(t, upsertErr) + + // Set up the websockets + messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "") + defer closeWS1() + + messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "") + defer closeWS2() + + // Confirm permissions for group mentions are correct + post := &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + Message: "@" + *group.Name, + } + require.True(t, th.App.allowGroupMentions(th.Context, post)) + + // Test sending notifications + _, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false) + require.NoError(t, err) + + received1 := <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"]) + + received2 := <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"]) + }) + + t.Run("should inform each user if they are following a thread that was posted in", func(t *testing.T) { + t.Log("BasicUser ", th.BasicUser.Id) + t.Log("sender ", sender.Id) + messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "") + defer closeWS1() + + messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "") + defer closeWS2() + + // Reply to a post made by BasicUser + post := &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + RootId: th.BasicPost.Id, + Message: "This is a test", + } + + // Use CreatePost instead of SendNotifications here since we need that to set up some threads state + _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, false, false) + require.Nil(t, appErr) + + received1 := <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["followers"]) + + received2 := <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assert.Nil(t, received2.GetData()["followers"]) + }) + + t.Run("should not include broadcast hook information in messages sent to users", func(t *testing.T) { + messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "") + defer closeWS1() + + messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "") + defer closeWS2() + + // For a post mentioning only one user, nobody in the channel should receive information about the broadcast hooks + post := &model.Post{ + UserId: sender.Id, + ChannelId: th.BasicChannel.Id, + Message: fmt.Sprintf("@%s", th.BasicUser.Username), + } + _, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false) + require.NoError(t, err) + + received1 := <-messages1 + require.Equal(t, model.WebsocketEventPosted, received1.EventType()) + assert.Nil(t, received1.GetBroadcast().BroadcastHooks) + assert.Nil(t, received1.GetBroadcast().BroadcastHookArgs) + + received2 := <-messages2 + require.Equal(t, model.WebsocketEventPosted, received2.EventType()) + assert.Nil(t, received2.GetBroadcast().BroadcastHooks) + assert.Nil(t, received2.GetBroadcast().BroadcastHookArgs) + }) +} + +func assertUnmarshalsTo(t *testing.T, expected any, actual any) { + t.Helper() + + val, err := json.Marshal(expected) + require.NoError(t, err) + + assert.JSONEq(t, string(val), actual.(string)) +} + +func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectionID string) (chan *model.WebSocketEvent, func()) { + var session *model.Session + var server *httptest.Server + var webConn *platform.WebConn + + closeWS := func() { + if webConn != nil { + webConn.Close() + } + if server != nil { + server.Close() + } + if session != nil { + appErr := th.App.RevokeSession(th.Context, session) + require.Nil(t, appErr) + } + } + + // Create a session for the user's connection + var appErr *model.AppError + session, appErr = th.App.CreateSession(th.Context, &model.Session{ + UserId: userID, + }) + require.Nil(t, appErr) + + // Create a channel and an HTTP server to handle incoming WS events + messages := make(chan *model.WebSocketEvent) + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := &websocket.Upgrader{} + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Log("Received error when upgrading WebSocket connection", err) + return + } + defer c.Close() + + for { + _, reader, err := c.NextReader() + if err != nil { + t.Log("Received error when reading from WebSocket connection", err) + break + } + + msg, err := model.WebSocketEventFromJSON(reader) + if err != nil { + t.Log("Received error when decoding from WebSocket connection", err) + break + } + + messages <- msg + } + })) + + // Connect the WebSocket + d := websocket.Dialer{} + ws, _, err := d.Dial("ws://"+server.Listener.Addr().String(), nil) + require.NoError(t, err) + + // Register the WebSocket with the server as a WebConn + if connectionID == "" { + connectionID = model.NewId() + } + webConn = th.App.Srv().Platform().NewWebConn(&platform.WebConnConfig{ + WebSocket: ws, + Session: *session, + TFunc: i18n.IdentityTfunc(), + Locale: "en", + ConnectionID: connectionID, + }, th.App, th.App.Channels()) + th.App.Srv().Platform().HubRegister(webConn) + + // Start reading from it + go webConn.Pump() + + // Read the events which always occur at the start of a WebSocket connection + received := <-messages + assert.Equal(t, model.WebsocketEventHello, received.EventType()) + + received = <-messages + assert.Equal(t, model.WebsocketEventStatusChange, received.EventType()) + + return messages, closeWS +} + +func TestConnectFakeWebSocket(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + teamID := th.BasicTeam.Id + userID := th.BasicUser.Id + + messages, closeWS := connectFakeWebSocket(t, th, userID, "") + defer closeWS() + + msg := model.NewWebSocketEvent(model.WebsocketEventPosted, teamID, "", "", nil, "") + th.App.Publish(msg) + + msg = model.NewWebSocketEvent("test_event_with_data", "", "", userID, nil, "") + msg.Add("key1", "value1") + msg.Add("key2", 2) + msg.Add("key3", []string{"three", "trois"}) + th.App.Publish(msg) + + received := <-messages + require.Equal(t, model.WebsocketEventPosted, received.EventType()) + assert.Equal(t, teamID, received.GetBroadcast().TeamId) + + received = <-messages + require.Equal(t, "test_event_with_data", received.EventType()) + assert.Equal(t, userID, received.GetBroadcast().UserId) + // These type changes are annoying but unavoidable because event data is untyped + assert.Equal(t, map[string]any{ + "key1": "value1", + "key2": float64(2), + "key3": []any{"three", "trois"}, + }, received.GetData()) +} + func TestSendNotificationsWithManyUsers(t *testing.T) { th := Setup(t).InitBasic() defer th.TearDown() diff --git a/server/channels/app/platform/helper_test.go b/server/channels/app/platform/helper_test.go index 22796a93e1..bf118b4286 100644 --- a/server/channels/app/platform/helper_test.go +++ b/server/channels/app/platform/helper_test.go @@ -187,7 +187,7 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo th.Service.SetLicense(nil) } - err = th.Service.Start() + err = th.Service.Start(nil) if err != nil { panic(err) } diff --git a/server/channels/app/platform/service.go b/server/channels/app/platform/service.go index 23e694d595..71c24bd55a 100644 --- a/server/channels/app/platform/service.go +++ b/server/channels/app/platform/service.go @@ -344,8 +344,8 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) { return ps, nil } -func (ps *PlatformService) Start() error { - ps.hubStart() +func (ps *PlatformService) Start(broadcastHooks map[string]BroadcastHook) error { + ps.hubStart(broadcastHooks) ps.configListenerId = ps.AddConfigListener(func(_, _ *model.Config) { ps.regenerateClientConfig() diff --git a/server/channels/app/platform/web_broadcast_hook.go b/server/channels/app/platform/web_broadcast_hook.go new file mode 100644 index 0000000000..0cff77a196 --- /dev/null +++ b/server/channels/app/platform/web_broadcast_hook.go @@ -0,0 +1,87 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package platform + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" +) + +type BroadcastHook interface { + // Process takes a WebSocket event and modifies it in some way. It is passed a HookedWebSocketEvent which allows + // safe modification of the event. + Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error +} + +func (h *Hub) runBroadcastHooks(msg *model.WebSocketEvent, webConn *WebConn, hookIDs []string, hookArgs []map[string]any) *model.WebSocketEvent { + if len(hookIDs) == 0 { + return msg + } + + hookedEvent := MakeHookedWebSocketEvent(msg) + + for i, hookID := range hookIDs { + hook := h.broadcastHooks[hookID] + args := hookArgs[i] + if hook == nil { + mlog.Warn("runBroadcastHooks: Unable to find broadcast hook", mlog.String("hook_id", hookID)) + continue + } + + hook.Process(hookedEvent, webConn, args) + } + + return hookedEvent.Event() +} + +// HookedWebSocketEvent is a wrapper for model.WebSocketEvent that is intended to provide a similar interface, except +// it ensures the original WebSocket event is not modified. +type HookedWebSocketEvent struct { + original *model.WebSocketEvent + copy *model.WebSocketEvent +} + +func MakeHookedWebSocketEvent(event *model.WebSocketEvent) *HookedWebSocketEvent { + return &HookedWebSocketEvent{ + original: event, + } +} + +func (he *HookedWebSocketEvent) Add(key string, value any) { + he.copyIfNecessary() + + he.copy.Add(key, value) +} + +func (he *HookedWebSocketEvent) EventType() string { + if he.copy == nil { + return he.original.EventType() + } + + return he.copy.EventType() +} + +// Get returns a value from the WebSocket event data. You should never mutate a value returned by this method. +func (he *HookedWebSocketEvent) Get(key string) any { + if he.copy == nil { + return he.original.GetData()[key] + } + + return he.copy.GetData()[key] +} + +// copyIfNecessary should be called by any mutative method to ensure that the copy is instantiated. +func (he *HookedWebSocketEvent) copyIfNecessary() { + if he.copy == nil { + he.copy = he.original.RemovePrecomputedJSON() + } +} + +func (he *HookedWebSocketEvent) Event() *model.WebSocketEvent { + if he.copy == nil { + return he.original + } + + return he.copy +} diff --git a/server/channels/app/platform/web_broadcast_hook_test.go b/server/channels/app/platform/web_broadcast_hook_test.go new file mode 100644 index 0000000000..556286437a --- /dev/null +++ b/server/channels/app/platform/web_broadcast_hook_test.go @@ -0,0 +1,206 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package platform + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const broadcastTest = "test_broadcast_hook" + +type testBroadcastHook struct{} + +func (h *testBroadcastHook) Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error { + if args["makes_changes"].(bool) { + changesMade, _ := msg.Get("changes_made").(int) + msg.Add("changes_made", changesMade+1) + } + + return nil +} + +func TestRunBroadcastHooks(t *testing.T) { + hub := &Hub{ + broadcastHooks: map[string]BroadcastHook{ + broadcastTest: &testBroadcastHook{}, + }, + } + webConn := &WebConn{} + + t.Run("should not allocate a new object when no hooks are passed", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + result := hub.runBroadcastHooks(event, webConn, nil, nil) + + assert.Same(t, event, result) + }) + + t.Run("should not allocate a new object when a hook is not making changes", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + hookIDs := []string{ + broadcastTest, + } + hookArgs := []map[string]any{ + { + "makes_changes": false, + }, + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + assert.Same(t, event, result) + }) + + t.Run("should allocate a new object and remove when a hook makes changes", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + hookIDs := []string{ + broadcastTest, + } + hookArgs := []map[string]any{ + { + "makes_changes": true, + }, + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + assert.NotSame(t, event, result) + assert.NotSame(t, event.GetData(), result.GetData()) + assert.Equal(t, map[string]any{}, event.GetData()) + assert.Equal(t, result.GetData(), map[string]any{ + "changes_made": 1, + }) + }) + + t.Run("should not allocate a new object when multiple hooks are not making changes", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + hookIDs := []string{ + broadcastTest, + broadcastTest, + broadcastTest, + } + hookArgs := []map[string]any{ + { + "makes_changes": false, + }, + { + "makes_changes": false, + }, + { + "makes_changes": false, + }, + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + assert.Same(t, event, result) + }) + + t.Run("should be able to make changes from only one of make hooks", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + var hookIDs []string + var hookArgs []map[string]any + for i := 0; i < 10; i++ { + hookIDs = append(hookIDs, broadcastTest) + hookArgs = append(hookArgs, map[string]any{ + "makes_changes": i == 6, + }) + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + assert.NotSame(t, event, result) + assert.NotSame(t, event.GetData(), result.GetData()) + assert.Equal(t, event.GetData(), map[string]any{}) + assert.Equal(t, result.GetData(), map[string]any{ + "changes_made": 1, + }) + }) + + t.Run("should be able to make changes from multiple hooks", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + + var hookIDs []string + var hookArgs []map[string]any + for i := 0; i < 10; i++ { + hookIDs = append(hookIDs, broadcastTest) + hookArgs = append(hookArgs, map[string]any{ + "makes_changes": true, + }) + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + assert.NotSame(t, event, result) + assert.NotSame(t, event.GetData(), result.GetData()) + assert.Equal(t, event.GetData(), map[string]any{}) + assert.Equal(t, result.GetData(), map[string]any{ + "changes_made": 10, + }) + }) + + t.Run("should not remove precomputed JSON when a hook doesn't make changes", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + event = event.PrecomputeJSON() + + // Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again + originalJSON, _ := event.ToJSON() + event.Add("data", 1234) + eventJSON, _ := event.ToJSON() + require.Equal(t, string(originalJSON), string(eventJSON)) + + hookIDs := []string{ + broadcastTest, + } + hookArgs := []map[string]any{ + { + "makes_changes": false, + }, + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + eventJSON, _ = event.ToJSON() + assert.Equal(t, string(originalJSON), string(eventJSON)) + + resultJSON, _ := result.ToJSON() + assert.Equal(t, originalJSON, resultJSON) + }) + + t.Run("should remove precomputed JSON when a hook makes changes", func(t *testing.T) { + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + event = event.PrecomputeJSON() + + // Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again + originalJSON, _ := event.ToJSON() + event.Add("data", 1234) + eventJSON, _ := event.ToJSON() + require.Equal(t, originalJSON, eventJSON) + + hookIDs := []string{ + broadcastTest, + } + hookArgs := []map[string]any{ + { + "makes_changes": true, + }, + } + + result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs) + + eventJSON, _ = event.ToJSON() + assert.Equal(t, string(originalJSON), string(eventJSON)) + + resultJSON, _ := result.ToJSON() + assert.NotEqual(t, originalJSON, resultJSON) + }) +} diff --git a/server/channels/app/platform/web_hub.go b/server/channels/app/platform/web_hub.go index 62110cef0d..337e283155 100644 --- a/server/channels/app/platform/web_hub.go +++ b/server/channels/app/platform/web_hub.go @@ -70,6 +70,7 @@ type Hub struct { explicitStop bool checkRegistered chan *webConnSessionMessage checkConn chan *webConnCheckMessage + broadcastHooks map[string]BroadcastHook } // newWebHub creates a new Hub. @@ -90,7 +91,7 @@ func newWebHub(ps *PlatformService) *Hub { } // hubStart starts all the hubs. -func (ps *PlatformService) hubStart() { +func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) { // Total number of hubs is twice the number of CPUs. numberOfHubs := runtime.NumCPU() * 2 ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs)) @@ -100,6 +101,7 @@ func (ps *PlatformService) hubStart() { for i := 0; i < numberOfHubs; i++ { hubs[i] = newWebHub(ps) hubs[i].connectionIndex = i + hubs[i].broadcastHooks = broadcastHooks hubs[i].Start() } // Assigning to the hubs slice without any mutex is fine because it is only assigned once @@ -492,14 +494,19 @@ func (h *Hub) Start() { if metrics := h.platform.metricsIFace; metrics != nil { metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1) } + + // Remove the broadcast hook information before precomputing the JSON so that those aren't included in it + msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks() + msg = msg.PrecomputeJSON() + broadcast := func(webConn *WebConn) { if !connIndex.Has(webConn) { return } if webConn.ShouldSendEvent(msg) { select { - case webConn.send <- msg: + case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs): default: // Don't log the warning if it's an inactive connection. if webConn.active.Load() { diff --git a/server/channels/app/platform/web_hub_test.go b/server/channels/app/platform/web_hub_test.go index 2b0cc0b840..6e7b215edb 100644 --- a/server/channels/app/platform/web_hub_test.go +++ b/server/channels/app/platform/web_hub_test.go @@ -4,6 +4,7 @@ package platform import ( + "bytes" "encoding/json" "net" "net/http" @@ -69,7 +70,7 @@ func TestHubStopWithMultipleConnections(t *testing.T) { }) require.NoError(t, err) - th.Service.Start() + th.Service.Start(nil) wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session) wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session) wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session) @@ -93,7 +94,7 @@ func TestHubStopRaceCondition(t *testing.T) { }) require.NoError(t, err) - th.Service.Start() + th.Service.Start(nil) wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session) defer wc1.Close() @@ -476,7 +477,7 @@ func TestHubIsRegistered(t *testing.T) { s := httptest.NewServer(dummyWebsocketHandler(t)) defer s.Close() - th.Service.Start() + th.Service.Start(nil) wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session) wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session) wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session) @@ -583,7 +584,7 @@ func BenchmarkGetHubForUserId(b *testing.B) { th := Setup(b).InitBasic() defer th.TearDown() - th.Service.Start() + th.Service.Start(nil) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -618,3 +619,54 @@ func TestClusterBroadcast(t *testing.T) { require.NoError(t, err) require.Equal(t, clusterEvent.Broadcast, broadcast) } + +func TestClusterBroadcastHooks(t *testing.T) { + t.Run("should send broadcast hook information across cluster", func(t *testing.T) { + testCluster := &testlib.FakeClusterInterface{} + + th := SetupWithCluster(t, testCluster) + defer th.TearDown() + + hookID := broadcastTest + hookArgs := map[string]any{ + "makes_changes": true, + } + + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + event.GetBroadcast().AddHook(hookID, hookArgs) + + th.Service.Publish(event) + + received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data)) + + require.NoError(t, err) + assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks) + assert.Equal(t, []map[string]any{hookArgs}, received.GetBroadcast().BroadcastHookArgs) + }) + + t.Run("should not preserve type information for args", func(t *testing.T) { + // This behaviour isn't ideal, but this test confirms that it hasn't changed + testCluster := &testlib.FakeClusterInterface{} + + th := SetupWithCluster(t, testCluster) + defer th.TearDown() + + hookID := "test_broadcast_hook_with_args" + hookArgs := map[string]any{ + "user": &model.User{Id: "user1"}, + "array": []string{"a", "b", "c"}, + } + + event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "") + event.GetBroadcast().AddHook(hookID, hookArgs) + + th.Service.Publish(event) + + received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data)) + + require.NoError(t, err) + assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks) + assert.IsType(t, map[string]any{}, received.GetBroadcast().BroadcastHookArgs[0]["user"]) + assert.IsType(t, []any{}, received.GetBroadcast().BroadcastHookArgs[0]["array"]) + }) +} diff --git a/server/channels/app/server.go b/server/channels/app/server.go index 999391e5c2..a3b5c4b7ac 100644 --- a/server/channels/app/server.go +++ b/server/channels/app/server.go @@ -270,7 +270,7 @@ func NewServer(options ...Option) (*Server, error) { // It is important to initialize the hub only after the global logger is set // to avoid race conditions while logging from inside the hub. // Step 4: Start platform - s.platform.Start() + s.platform.Start(s.makeBroadcastHooks()) // NOTE: There should be no call to App.Srv().Channels() before step 5 is done // otherwise it will throw a panic. diff --git a/server/channels/app/web_broadcast_hooks.go b/server/channels/app/web_broadcast_hooks.go new file mode 100644 index 0000000000..f17ee9d740 --- /dev/null +++ b/server/channels/app/web_broadcast_hooks.go @@ -0,0 +1,96 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "encoding/json" + "fmt" + + "github.com/mattermost/mattermost/server/public/model" + pUtils "github.com/mattermost/mattermost/server/public/utils" + "github.com/mattermost/mattermost/server/v8/channels/app/platform" + "github.com/pkg/errors" +) + +const ( + broadcastAddMentions = "add_mentions" + broadcastAddFollowers = "add_followers" +) + +func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook { + return map[string]platform.BroadcastHook{ + broadcastAddMentions: &addMentionsBroadcastHook{}, + broadcastAddFollowers: &addFollowersBroadcastHook{}, + } +} + +type addMentionsBroadcastHook struct{} + +func (h *addMentionsBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error { + mentions, err := getTypedArg[model.StringArray](args, "mentions") + if err != nil { + return errors.Wrap(err, "Invalid mentions value passed to addMentionsBroadcastHook") + } + + if len(mentions) > 0 && pUtils.Contains[string](mentions, webConn.UserId) { + // Note that the client expects this field to be stringified + msg.Add("mentions", model.ArrayToJSON([]string{webConn.UserId})) + } + + return nil +} + +func useAddMentionsHook(message *model.WebSocketEvent, mentionedUsers model.StringArray) { + message.GetBroadcast().AddHook(broadcastAddMentions, map[string]any{ + "mentions": mentionedUsers, + }) +} + +type addFollowersBroadcastHook struct{} + +func (h *addFollowersBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error { + followers, err := getTypedArg[model.StringArray](args, "followers") + if err != nil { + return errors.Wrap(err, "Invalid followers value passed to addFollowersBroadcastHook") + } + + if len(followers) > 0 && pUtils.Contains[string](followers, webConn.UserId) { + // Note that the client expects this field to be stringified + msg.Add("followers", model.ArrayToJSON([]string{webConn.UserId})) + } + + return nil +} + +func useAddFollowersHook(message *model.WebSocketEvent, followers model.StringArray) { + message.GetBroadcast().AddHook(broadcastAddFollowers, map[string]any{ + "followers": followers, + }) +} + +// getTypedArg returns a correctly typed hook argument with the given key, reinterpreting the type using JSON encoding +// if necessary. This is needed because broadcast hook args are JSON encoded in a multi-server environment, and any +// type information is lost because those types aren't known at decode time. +func getTypedArg[T any](args map[string]any, key string) (T, error) { + var value T + + untyped, ok := args[key] + if !ok { + return value, fmt.Errorf("No argument found with key: %s", key) + } + + // If the value is already correct, just return it + if typed, ok := untyped.(T); ok { + return typed, nil + } + + // Marshal and unmarshal the data with the correct typing information + buf, err := json.Marshal(untyped) + if err != nil { + return value, err + } + + err = json.Unmarshal(buf, &value) + return value, err +} diff --git a/server/channels/app/web_broadcast_hooks_test.go b/server/channels/app/web_broadcast_hooks_test.go new file mode 100644 index 0000000000..a4c88d8cda --- /dev/null +++ b/server/channels/app/web_broadcast_hooks_test.go @@ -0,0 +1,114 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/app/platform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddMentionsHook_Process(t *testing.T) { + hook := &addMentionsBroadcastHook{} + + userID := model.NewId() + otherUserID := model.NewId() + + webConn := &platform.WebConn{ + UserId: userID, + } + + t.Run("should add a mentions entry for the current user", func(t *testing.T) { + msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")) + + require.Nil(t, msg.Event().GetData()["mentions"]) + + hook.Process(msg, webConn, map[string]any{ + "mentions": model.StringArray{userID}, + }) + + assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"]) + assert.Nil(t, msg.Event().GetData()["followers"]) + }) + + t.Run("should not add a mentions entry for another user", func(t *testing.T) { + msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")) + + require.Nil(t, msg.Event().GetData()["mentions"]) + + hook.Process(msg, webConn, map[string]any{ + "mentions": model.StringArray{otherUserID}, + }) + + assert.Nil(t, msg.Event().GetData()["mentions"]) + }) +} + +func TestAddFollowersHook_Process(t *testing.T) { + hook := &addFollowersBroadcastHook{} + + userID := model.NewId() + otherUserID := model.NewId() + + webConn := &platform.WebConn{ + UserId: userID, + } + + t.Run("should add a followers entry for the current user", func(t *testing.T) { + msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")) + + require.Nil(t, msg.Event().GetData()["followers"]) + + hook.Process(msg, webConn, map[string]any{ + "followers": model.StringArray{userID}, + }) + + assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"]) + }) + + t.Run("should not add a followers entry for another user", func(t *testing.T) { + msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")) + + require.Nil(t, msg.Event().GetData()["followers"]) + + hook.Process(msg, webConn, map[string]any{ + "followers": model.StringArray{otherUserID}, + }) + + assert.Nil(t, msg.Event().GetData()["followers"]) + }) +} + +func TestAddMentionsAndAddFollowersHooks(t *testing.T) { + addMentionsHook := &addMentionsBroadcastHook{} + addFollowersHook := &addFollowersBroadcastHook{} + + userID := model.NewId() + + webConn := &platform.WebConn{ + UserId: userID, + } + + msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")) + + originalData := msg.Event().GetData() + + require.Nil(t, originalData["mentions"]) + require.Nil(t, originalData["followers"]) + + addMentionsHook.Process(msg, webConn, map[string]any{ + "mentions": model.StringArray{userID}, + }) + addFollowersHook.Process(msg, webConn, map[string]any{ + "followers": model.StringArray{userID}, + }) + + t.Run("should be able to add both mentions and followers to a single event", func(t *testing.T) { + assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"]) + assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"]) + }) +} diff --git a/server/public/model/websocket_message.go b/server/public/model/websocket_message.go index 80119ef1ca..0f5ee4a7cf 100644 --- a/server/public/model/websocket_message.go +++ b/server/public/model/websocket_message.go @@ -104,6 +104,17 @@ type WebsocketBroadcast struct { // ReliableClusterSend indicates whether or not the message should // be sent through the cluster using the reliable, TCP backed channel. ReliableClusterSend bool `json:"-"` + + // BroadcastHooks is a slice of hooks IDs used to process events before sending them on individual connections. The + // IDs should be understood by the WebSocket code. + // + // This field should never be sent to the client. + BroadcastHooks []string `json:"broadcast_hooks,omitempty"` + // BroadcastHookArgs is a slice of named arguments for each hook invocation. The index of each entry corresponds to + // the index of a hook ID in BroadcastHooks + // + // This field should never be sent to the client. + BroadcastHookArgs []map[string]any `json:"broadcast_hook_args,omitempty"` } func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast { @@ -124,10 +135,17 @@ func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast { c.OmitConnectionId = wb.OmitConnectionId c.ContainsSanitizedData = wb.ContainsSanitizedData c.ContainsSensitiveData = wb.ContainsSensitiveData + c.BroadcastHooks = wb.BroadcastHooks + c.BroadcastHookArgs = wb.BroadcastHookArgs return &c } +func (wb *WebsocketBroadcast) AddHook(hookID string, hookArgs map[string]any) { + wb.BroadcastHooks = append(wb.BroadcastHooks, hookID) + wb.BroadcastHookArgs = append(wb.BroadcastHookArgs, hookArgs) +} + type precomputedWebSocketEventJSON struct { Event json.RawMessage Data json.RawMessage @@ -190,6 +208,32 @@ func (ev *WebSocketEvent) PrecomputeJSON() *WebSocketEvent { return evCopy } +func (ev *WebSocketEvent) RemovePrecomputedJSON() *WebSocketEvent { + evCopy := ev.DeepCopy() + evCopy.precomputedJSON = nil + return evCopy +} + +// WithoutBroadcastHooks gets the broadcast hook information from a WebSocketEvent and returns the event without that. +// If the event has broadcast hooks, a copy of the event is returned. Otherwise, the original event is returned. This +// is intended to be called before the event is sent to the client. +func (ev *WebSocketEvent) WithoutBroadcastHooks() (*WebSocketEvent, []string, []map[string]any) { + hooks := ev.broadcast.BroadcastHooks + hookArgs := ev.broadcast.BroadcastHookArgs + + if len(hooks) == 0 && len(hookArgs) == 0 { + return ev, hooks, hookArgs + } + + evCopy := ev.Copy() + evCopy.broadcast = ev.broadcast.copy() + + evCopy.broadcast.BroadcastHooks = nil + evCopy.broadcast.BroadcastHookArgs = nil + + return evCopy, hooks, hookArgs +} + func (ev *WebSocketEvent) Add(key string, value any) { ev.data[key] = value } @@ -219,17 +263,9 @@ func (ev *WebSocketEvent) Copy() *WebSocketEvent { } func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent { - var dataCopy map[string]any - if ev.data != nil { - dataCopy = make(map[string]any, len(ev.data)) - for k, v := range ev.data { - dataCopy[k] = v - } - } - evCopy := &WebSocketEvent{ event: ev.event, - data: dataCopy, + data: copyMap(ev.data), broadcast: ev.broadcast.copy(), sequence: ev.sequence, precomputedJSON: ev.precomputedJSON.copy(), @@ -237,6 +273,14 @@ func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent { return evCopy } +func copyMap[K comparable, V any](m map[K]V) map[K]V { + dataCopy := make(map[K]V, len(m)) + for k, v := range m { + dataCopy[k] = v + } + return dataCopy +} + func (ev *WebSocketEvent) GetData() map[string]any { return ev.data }