mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-54238 Add WebSocket broadcast hook and don't broadcast when other users were mentioned (#24641)
* MM-54238 Initial implementation * MM-54238 Move websocket hook into app package * MM-54238 Add tests for mentions in posted websocket messages * Fix styling * Fix other styling * Idiomatic ID naming for new code * Fix more styles * Separate hooks to add mentions and followers * Improved error handling for invalid types in hooks * Rename HasChanges to ShouldProcess * Pass broadcast hooks through hubStart * Add test helper for asserting json unmarshaling * Fix missing arguments in tests * Ensure broadcast hooks are sent across the cluster and not to users * Ensure tests actually cover following a post * Fix code broken by merge * Go vet again... * Deep copy event before processing it with hooks * Replace RemoveBroadcastHooks with WithoutBroadcastHooks * Address feedback * Add helper to fix type information for hook args * Wrap WebSocketEvent and simplify BroadcastHook * Address feedback * Address feedback
This commit is contained in:
parent
5e62ba8ccc
commit
ef66f7beab
@ -517,12 +517,12 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(mentionedUsersList) != 0 {
|
if len(mentionedUsersList) > 0 {
|
||||||
message.Add("mentions", model.ArrayToJSON(mentionedUsersList))
|
useAddMentionsHook(message, mentionedUsersList)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(notificationsForCRT.Desktop) != 0 {
|
if len(notificationsForCRT.Desktop) > 0 {
|
||||||
message.Add("followers", model.ArrayToJSON(notificationsForCRT.Desktop))
|
useAddFollowersHook(message, notificationsForCRT.Desktop)
|
||||||
}
|
}
|
||||||
|
|
||||||
published, err := a.publishWebsocketEventForPermalinkPost(c, post, message)
|
published, err := a.publishWebsocketEventForPermalinkPost(c, post, message)
|
||||||
|
@ -4,10 +4,14 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@ -15,6 +19,7 @@ import (
|
|||||||
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
||||||
pUtils "github.com/mattermost/mattermost/server/public/utils"
|
pUtils "github.com/mattermost/mattermost/server/public/utils"
|
||||||
|
|
||||||
|
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -212,6 +217,314 @@ func TestSendNotifications(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSendNotifications_MentionsFollowers(t *testing.T) {
|
||||||
|
th := Setup(t).InitBasic()
|
||||||
|
defer th.TearDown()
|
||||||
|
|
||||||
|
th.AddUserToChannel(th.BasicUser2, th.BasicChannel)
|
||||||
|
|
||||||
|
sender := th.CreateUser()
|
||||||
|
|
||||||
|
th.LinkUserToTeam(sender, th.BasicTeam)
|
||||||
|
member := th.AddUserToChannel(sender, th.BasicChannel)
|
||||||
|
|
||||||
|
t.Run("should inform each user if they were mentioned by a post", func(t *testing.T) {
|
||||||
|
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||||
|
defer closeWS1()
|
||||||
|
|
||||||
|
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||||
|
defer closeWS2()
|
||||||
|
|
||||||
|
// First post mentioning the whole channel
|
||||||
|
post := &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
Message: "@channel",
|
||||||
|
}
|
||||||
|
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received1 := <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||||
|
|
||||||
|
received2 := <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||||
|
|
||||||
|
// Second post mentioning both users individually
|
||||||
|
post = &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
Message: fmt.Sprintf("@%s @%s", th.BasicUser.Username, th.BasicUser2.Username),
|
||||||
|
}
|
||||||
|
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received1 = <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||||
|
|
||||||
|
received2 = <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||||
|
|
||||||
|
// Third post mentioning a single user
|
||||||
|
post = &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
Message: "@" + th.BasicUser.Username,
|
||||||
|
}
|
||||||
|
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received1 = <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||||
|
|
||||||
|
received2 = <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assert.Nil(t, received2.GetData()["mentions"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should inform each user in a group if they were mentioned by a post", func(t *testing.T) {
|
||||||
|
// Make the sender a channel_admin because that's needed for group mentions
|
||||||
|
originalRoles := member.Roles
|
||||||
|
member.Roles = "channel_user channel_admin"
|
||||||
|
_, appErr := th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, member.Roles)
|
||||||
|
require.Nil(t, appErr)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, originalRoles)
|
||||||
|
}()
|
||||||
|
|
||||||
|
th.App.Srv().SetLicense(getLicWithSkuShortName(model.LicenseShortSkuEnterprise))
|
||||||
|
|
||||||
|
// Make a group and add users
|
||||||
|
group := th.CreateGroup()
|
||||||
|
group.AllowReference = true
|
||||||
|
group, updateErr := th.App.UpdateGroup(group)
|
||||||
|
require.Nil(t, updateErr)
|
||||||
|
|
||||||
|
_, upsertErr := th.App.UpsertGroupMember(group.Id, th.BasicUser.Id)
|
||||||
|
require.Nil(t, upsertErr)
|
||||||
|
_, upsertErr = th.App.UpsertGroupMember(group.Id, th.BasicUser2.Id)
|
||||||
|
require.Nil(t, upsertErr)
|
||||||
|
|
||||||
|
// Set up the websockets
|
||||||
|
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||||
|
defer closeWS1()
|
||||||
|
|
||||||
|
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||||
|
defer closeWS2()
|
||||||
|
|
||||||
|
// Confirm permissions for group mentions are correct
|
||||||
|
post := &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
Message: "@" + *group.Name,
|
||||||
|
}
|
||||||
|
require.True(t, th.App.allowGroupMentions(th.Context, post))
|
||||||
|
|
||||||
|
// Test sending notifications
|
||||||
|
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received1 := <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||||
|
|
||||||
|
received2 := <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should inform each user if they are following a thread that was posted in", func(t *testing.T) {
|
||||||
|
t.Log("BasicUser ", th.BasicUser.Id)
|
||||||
|
t.Log("sender ", sender.Id)
|
||||||
|
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||||
|
defer closeWS1()
|
||||||
|
|
||||||
|
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||||
|
defer closeWS2()
|
||||||
|
|
||||||
|
// Reply to a post made by BasicUser
|
||||||
|
post := &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
RootId: th.BasicPost.Id,
|
||||||
|
Message: "This is a test",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use CreatePost instead of SendNotifications here since we need that to set up some threads state
|
||||||
|
_, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, false, false)
|
||||||
|
require.Nil(t, appErr)
|
||||||
|
|
||||||
|
received1 := <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["followers"])
|
||||||
|
|
||||||
|
received2 := <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assert.Nil(t, received2.GetData()["followers"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not include broadcast hook information in messages sent to users", func(t *testing.T) {
|
||||||
|
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||||
|
defer closeWS1()
|
||||||
|
|
||||||
|
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||||
|
defer closeWS2()
|
||||||
|
|
||||||
|
// For a post mentioning only one user, nobody in the channel should receive information about the broadcast hooks
|
||||||
|
post := &model.Post{
|
||||||
|
UserId: sender.Id,
|
||||||
|
ChannelId: th.BasicChannel.Id,
|
||||||
|
Message: fmt.Sprintf("@%s", th.BasicUser.Username),
|
||||||
|
}
|
||||||
|
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
received1 := <-messages1
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||||
|
assert.Nil(t, received1.GetBroadcast().BroadcastHooks)
|
||||||
|
assert.Nil(t, received1.GetBroadcast().BroadcastHookArgs)
|
||||||
|
|
||||||
|
received2 := <-messages2
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||||
|
assert.Nil(t, received2.GetBroadcast().BroadcastHooks)
|
||||||
|
assert.Nil(t, received2.GetBroadcast().BroadcastHookArgs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertUnmarshalsTo(t *testing.T, expected any, actual any) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
val, err := json.Marshal(expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.JSONEq(t, string(val), actual.(string))
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectionID string) (chan *model.WebSocketEvent, func()) {
|
||||||
|
var session *model.Session
|
||||||
|
var server *httptest.Server
|
||||||
|
var webConn *platform.WebConn
|
||||||
|
|
||||||
|
closeWS := func() {
|
||||||
|
if webConn != nil {
|
||||||
|
webConn.Close()
|
||||||
|
}
|
||||||
|
if server != nil {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
appErr := th.App.RevokeSession(th.Context, session)
|
||||||
|
require.Nil(t, appErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a session for the user's connection
|
||||||
|
var appErr *model.AppError
|
||||||
|
session, appErr = th.App.CreateSession(th.Context, &model.Session{
|
||||||
|
UserId: userID,
|
||||||
|
})
|
||||||
|
require.Nil(t, appErr)
|
||||||
|
|
||||||
|
// Create a channel and an HTTP server to handle incoming WS events
|
||||||
|
messages := make(chan *model.WebSocketEvent)
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
upgrader := &websocket.Upgrader{}
|
||||||
|
|
||||||
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("Received error when upgrading WebSocket connection", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, reader, err := c.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
t.Log("Received error when reading from WebSocket connection", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := model.WebSocketEventFromJSON(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("Received error when decoding from WebSocket connection", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
messages <- msg
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Connect the WebSocket
|
||||||
|
d := websocket.Dialer{}
|
||||||
|
ws, _, err := d.Dial("ws://"+server.Listener.Addr().String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Register the WebSocket with the server as a WebConn
|
||||||
|
if connectionID == "" {
|
||||||
|
connectionID = model.NewId()
|
||||||
|
}
|
||||||
|
webConn = th.App.Srv().Platform().NewWebConn(&platform.WebConnConfig{
|
||||||
|
WebSocket: ws,
|
||||||
|
Session: *session,
|
||||||
|
TFunc: i18n.IdentityTfunc(),
|
||||||
|
Locale: "en",
|
||||||
|
ConnectionID: connectionID,
|
||||||
|
}, th.App, th.App.Channels())
|
||||||
|
th.App.Srv().Platform().HubRegister(webConn)
|
||||||
|
|
||||||
|
// Start reading from it
|
||||||
|
go webConn.Pump()
|
||||||
|
|
||||||
|
// Read the events which always occur at the start of a WebSocket connection
|
||||||
|
received := <-messages
|
||||||
|
assert.Equal(t, model.WebsocketEventHello, received.EventType())
|
||||||
|
|
||||||
|
received = <-messages
|
||||||
|
assert.Equal(t, model.WebsocketEventStatusChange, received.EventType())
|
||||||
|
|
||||||
|
return messages, closeWS
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectFakeWebSocket(t *testing.T) {
|
||||||
|
th := Setup(t).InitBasic()
|
||||||
|
defer th.TearDown()
|
||||||
|
|
||||||
|
teamID := th.BasicTeam.Id
|
||||||
|
userID := th.BasicUser.Id
|
||||||
|
|
||||||
|
messages, closeWS := connectFakeWebSocket(t, th, userID, "")
|
||||||
|
defer closeWS()
|
||||||
|
|
||||||
|
msg := model.NewWebSocketEvent(model.WebsocketEventPosted, teamID, "", "", nil, "")
|
||||||
|
th.App.Publish(msg)
|
||||||
|
|
||||||
|
msg = model.NewWebSocketEvent("test_event_with_data", "", "", userID, nil, "")
|
||||||
|
msg.Add("key1", "value1")
|
||||||
|
msg.Add("key2", 2)
|
||||||
|
msg.Add("key3", []string{"three", "trois"})
|
||||||
|
th.App.Publish(msg)
|
||||||
|
|
||||||
|
received := <-messages
|
||||||
|
require.Equal(t, model.WebsocketEventPosted, received.EventType())
|
||||||
|
assert.Equal(t, teamID, received.GetBroadcast().TeamId)
|
||||||
|
|
||||||
|
received = <-messages
|
||||||
|
require.Equal(t, "test_event_with_data", received.EventType())
|
||||||
|
assert.Equal(t, userID, received.GetBroadcast().UserId)
|
||||||
|
// These type changes are annoying but unavoidable because event data is untyped
|
||||||
|
assert.Equal(t, map[string]any{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": float64(2),
|
||||||
|
"key3": []any{"three", "trois"},
|
||||||
|
}, received.GetData())
|
||||||
|
}
|
||||||
|
|
||||||
func TestSendNotificationsWithManyUsers(t *testing.T) {
|
func TestSendNotificationsWithManyUsers(t *testing.T) {
|
||||||
th := Setup(t).InitBasic()
|
th := Setup(t).InitBasic()
|
||||||
defer th.TearDown()
|
defer th.TearDown()
|
||||||
|
@ -187,7 +187,7 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
|
|||||||
th.Service.SetLicense(nil)
|
th.Service.SetLicense(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = th.Service.Start()
|
err = th.Service.Start(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -344,8 +344,8 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) {
|
|||||||
return ps, nil
|
return ps, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *PlatformService) Start() error {
|
func (ps *PlatformService) Start(broadcastHooks map[string]BroadcastHook) error {
|
||||||
ps.hubStart()
|
ps.hubStart(broadcastHooks)
|
||||||
|
|
||||||
ps.configListenerId = ps.AddConfigListener(func(_, _ *model.Config) {
|
ps.configListenerId = ps.AddConfigListener(func(_, _ *model.Config) {
|
||||||
ps.regenerateClientConfig()
|
ps.regenerateClientConfig()
|
||||||
|
87
server/channels/app/platform/web_broadcast_hook.go
Normal file
87
server/channels/app/platform/web_broadcast_hook.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package platform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BroadcastHook interface {
|
||||||
|
// Process takes a WebSocket event and modifies it in some way. It is passed a HookedWebSocketEvent which allows
|
||||||
|
// safe modification of the event.
|
||||||
|
Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) runBroadcastHooks(msg *model.WebSocketEvent, webConn *WebConn, hookIDs []string, hookArgs []map[string]any) *model.WebSocketEvent {
|
||||||
|
if len(hookIDs) == 0 {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
hookedEvent := MakeHookedWebSocketEvent(msg)
|
||||||
|
|
||||||
|
for i, hookID := range hookIDs {
|
||||||
|
hook := h.broadcastHooks[hookID]
|
||||||
|
args := hookArgs[i]
|
||||||
|
if hook == nil {
|
||||||
|
mlog.Warn("runBroadcastHooks: Unable to find broadcast hook", mlog.String("hook_id", hookID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hook.Process(hookedEvent, webConn, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hookedEvent.Event()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookedWebSocketEvent is a wrapper for model.WebSocketEvent that is intended to provide a similar interface, except
|
||||||
|
// it ensures the original WebSocket event is not modified.
|
||||||
|
type HookedWebSocketEvent struct {
|
||||||
|
original *model.WebSocketEvent
|
||||||
|
copy *model.WebSocketEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeHookedWebSocketEvent(event *model.WebSocketEvent) *HookedWebSocketEvent {
|
||||||
|
return &HookedWebSocketEvent{
|
||||||
|
original: event,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (he *HookedWebSocketEvent) Add(key string, value any) {
|
||||||
|
he.copyIfNecessary()
|
||||||
|
|
||||||
|
he.copy.Add(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (he *HookedWebSocketEvent) EventType() string {
|
||||||
|
if he.copy == nil {
|
||||||
|
return he.original.EventType()
|
||||||
|
}
|
||||||
|
|
||||||
|
return he.copy.EventType()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value from the WebSocket event data. You should never mutate a value returned by this method.
|
||||||
|
func (he *HookedWebSocketEvent) Get(key string) any {
|
||||||
|
if he.copy == nil {
|
||||||
|
return he.original.GetData()[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
return he.copy.GetData()[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyIfNecessary should be called by any mutative method to ensure that the copy is instantiated.
|
||||||
|
func (he *HookedWebSocketEvent) copyIfNecessary() {
|
||||||
|
if he.copy == nil {
|
||||||
|
he.copy = he.original.RemovePrecomputedJSON()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (he *HookedWebSocketEvent) Event() *model.WebSocketEvent {
|
||||||
|
if he.copy == nil {
|
||||||
|
return he.original
|
||||||
|
}
|
||||||
|
|
||||||
|
return he.copy
|
||||||
|
}
|
206
server/channels/app/platform/web_broadcast_hook_test.go
Normal file
206
server/channels/app/platform/web_broadcast_hook_test.go
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package platform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const broadcastTest = "test_broadcast_hook"
|
||||||
|
|
||||||
|
type testBroadcastHook struct{}
|
||||||
|
|
||||||
|
func (h *testBroadcastHook) Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error {
|
||||||
|
if args["makes_changes"].(bool) {
|
||||||
|
changesMade, _ := msg.Get("changes_made").(int)
|
||||||
|
msg.Add("changes_made", changesMade+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunBroadcastHooks(t *testing.T) {
|
||||||
|
hub := &Hub{
|
||||||
|
broadcastHooks: map[string]BroadcastHook{
|
||||||
|
broadcastTest: &testBroadcastHook{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
webConn := &WebConn{}
|
||||||
|
|
||||||
|
t.Run("should not allocate a new object when no hooks are passed", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, nil, nil)
|
||||||
|
|
||||||
|
assert.Same(t, event, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not allocate a new object when a hook is not making changes", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
hookIDs := []string{
|
||||||
|
broadcastTest,
|
||||||
|
}
|
||||||
|
hookArgs := []map[string]any{
|
||||||
|
{
|
||||||
|
"makes_changes": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
assert.Same(t, event, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should allocate a new object and remove when a hook makes changes", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
hookIDs := []string{
|
||||||
|
broadcastTest,
|
||||||
|
}
|
||||||
|
hookArgs := []map[string]any{
|
||||||
|
{
|
||||||
|
"makes_changes": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
assert.NotSame(t, event, result)
|
||||||
|
assert.NotSame(t, event.GetData(), result.GetData())
|
||||||
|
assert.Equal(t, map[string]any{}, event.GetData())
|
||||||
|
assert.Equal(t, result.GetData(), map[string]any{
|
||||||
|
"changes_made": 1,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not allocate a new object when multiple hooks are not making changes", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
hookIDs := []string{
|
||||||
|
broadcastTest,
|
||||||
|
broadcastTest,
|
||||||
|
broadcastTest,
|
||||||
|
}
|
||||||
|
hookArgs := []map[string]any{
|
||||||
|
{
|
||||||
|
"makes_changes": false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"makes_changes": false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"makes_changes": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
assert.Same(t, event, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should be able to make changes from only one of make hooks", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
var hookIDs []string
|
||||||
|
var hookArgs []map[string]any
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
hookIDs = append(hookIDs, broadcastTest)
|
||||||
|
hookArgs = append(hookArgs, map[string]any{
|
||||||
|
"makes_changes": i == 6,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
assert.NotSame(t, event, result)
|
||||||
|
assert.NotSame(t, event.GetData(), result.GetData())
|
||||||
|
assert.Equal(t, event.GetData(), map[string]any{})
|
||||||
|
assert.Equal(t, result.GetData(), map[string]any{
|
||||||
|
"changes_made": 1,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should be able to make changes from multiple hooks", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
|
||||||
|
var hookIDs []string
|
||||||
|
var hookArgs []map[string]any
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
hookIDs = append(hookIDs, broadcastTest)
|
||||||
|
hookArgs = append(hookArgs, map[string]any{
|
||||||
|
"makes_changes": true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
assert.NotSame(t, event, result)
|
||||||
|
assert.NotSame(t, event.GetData(), result.GetData())
|
||||||
|
assert.Equal(t, event.GetData(), map[string]any{})
|
||||||
|
assert.Equal(t, result.GetData(), map[string]any{
|
||||||
|
"changes_made": 10,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not remove precomputed JSON when a hook doesn't make changes", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
event = event.PrecomputeJSON()
|
||||||
|
|
||||||
|
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
|
||||||
|
originalJSON, _ := event.ToJSON()
|
||||||
|
event.Add("data", 1234)
|
||||||
|
eventJSON, _ := event.ToJSON()
|
||||||
|
require.Equal(t, string(originalJSON), string(eventJSON))
|
||||||
|
|
||||||
|
hookIDs := []string{
|
||||||
|
broadcastTest,
|
||||||
|
}
|
||||||
|
hookArgs := []map[string]any{
|
||||||
|
{
|
||||||
|
"makes_changes": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
eventJSON, _ = event.ToJSON()
|
||||||
|
assert.Equal(t, string(originalJSON), string(eventJSON))
|
||||||
|
|
||||||
|
resultJSON, _ := result.ToJSON()
|
||||||
|
assert.Equal(t, originalJSON, resultJSON)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should remove precomputed JSON when a hook makes changes", func(t *testing.T) {
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
event = event.PrecomputeJSON()
|
||||||
|
|
||||||
|
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
|
||||||
|
originalJSON, _ := event.ToJSON()
|
||||||
|
event.Add("data", 1234)
|
||||||
|
eventJSON, _ := event.ToJSON()
|
||||||
|
require.Equal(t, originalJSON, eventJSON)
|
||||||
|
|
||||||
|
hookIDs := []string{
|
||||||
|
broadcastTest,
|
||||||
|
}
|
||||||
|
hookArgs := []map[string]any{
|
||||||
|
{
|
||||||
|
"makes_changes": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||||
|
|
||||||
|
eventJSON, _ = event.ToJSON()
|
||||||
|
assert.Equal(t, string(originalJSON), string(eventJSON))
|
||||||
|
|
||||||
|
resultJSON, _ := result.ToJSON()
|
||||||
|
assert.NotEqual(t, originalJSON, resultJSON)
|
||||||
|
})
|
||||||
|
}
|
@ -70,6 +70,7 @@ type Hub struct {
|
|||||||
explicitStop bool
|
explicitStop bool
|
||||||
checkRegistered chan *webConnSessionMessage
|
checkRegistered chan *webConnSessionMessage
|
||||||
checkConn chan *webConnCheckMessage
|
checkConn chan *webConnCheckMessage
|
||||||
|
broadcastHooks map[string]BroadcastHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// newWebHub creates a new Hub.
|
// newWebHub creates a new Hub.
|
||||||
@ -90,7 +91,7 @@ func newWebHub(ps *PlatformService) *Hub {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hubStart starts all the hubs.
|
// hubStart starts all the hubs.
|
||||||
func (ps *PlatformService) hubStart() {
|
func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) {
|
||||||
// Total number of hubs is twice the number of CPUs.
|
// Total number of hubs is twice the number of CPUs.
|
||||||
numberOfHubs := runtime.NumCPU() * 2
|
numberOfHubs := runtime.NumCPU() * 2
|
||||||
ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
|
ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
|
||||||
@ -100,6 +101,7 @@ func (ps *PlatformService) hubStart() {
|
|||||||
for i := 0; i < numberOfHubs; i++ {
|
for i := 0; i < numberOfHubs; i++ {
|
||||||
hubs[i] = newWebHub(ps)
|
hubs[i] = newWebHub(ps)
|
||||||
hubs[i].connectionIndex = i
|
hubs[i].connectionIndex = i
|
||||||
|
hubs[i].broadcastHooks = broadcastHooks
|
||||||
hubs[i].Start()
|
hubs[i].Start()
|
||||||
}
|
}
|
||||||
// Assigning to the hubs slice without any mutex is fine because it is only assigned once
|
// Assigning to the hubs slice without any mutex is fine because it is only assigned once
|
||||||
@ -492,14 +494,19 @@ func (h *Hub) Start() {
|
|||||||
if metrics := h.platform.metricsIFace; metrics != nil {
|
if metrics := h.platform.metricsIFace; metrics != nil {
|
||||||
metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
|
metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove the broadcast hook information before precomputing the JSON so that those aren't included in it
|
||||||
|
msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks()
|
||||||
|
|
||||||
msg = msg.PrecomputeJSON()
|
msg = msg.PrecomputeJSON()
|
||||||
|
|
||||||
broadcast := func(webConn *WebConn) {
|
broadcast := func(webConn *WebConn) {
|
||||||
if !connIndex.Has(webConn) {
|
if !connIndex.Has(webConn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if webConn.ShouldSendEvent(msg) {
|
if webConn.ShouldSendEvent(msg) {
|
||||||
select {
|
select {
|
||||||
case webConn.send <- msg:
|
case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
|
||||||
default:
|
default:
|
||||||
// Don't log the warning if it's an inactive connection.
|
// Don't log the warning if it's an inactive connection.
|
||||||
if webConn.active.Load() {
|
if webConn.active.Load() {
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package platform
|
package platform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -69,7 +70,7 @@ func TestHubStopWithMultipleConnections(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
th.Service.Start()
|
th.Service.Start(nil)
|
||||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
@ -93,7 +94,7 @@ func TestHubStopRaceCondition(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
th.Service.Start()
|
th.Service.Start(nil)
|
||||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
defer wc1.Close()
|
defer wc1.Close()
|
||||||
|
|
||||||
@ -476,7 +477,7 @@ func TestHubIsRegistered(t *testing.T) {
|
|||||||
s := httptest.NewServer(dummyWebsocketHandler(t))
|
s := httptest.NewServer(dummyWebsocketHandler(t))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
th.Service.Start()
|
th.Service.Start(nil)
|
||||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||||
@ -583,7 +584,7 @@ func BenchmarkGetHubForUserId(b *testing.B) {
|
|||||||
th := Setup(b).InitBasic()
|
th := Setup(b).InitBasic()
|
||||||
defer th.TearDown()
|
defer th.TearDown()
|
||||||
|
|
||||||
th.Service.Start()
|
th.Service.Start(nil)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@ -618,3 +619,54 @@ func TestClusterBroadcast(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, clusterEvent.Broadcast, broadcast)
|
require.Equal(t, clusterEvent.Broadcast, broadcast)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClusterBroadcastHooks(t *testing.T) {
|
||||||
|
t.Run("should send broadcast hook information across cluster", func(t *testing.T) {
|
||||||
|
testCluster := &testlib.FakeClusterInterface{}
|
||||||
|
|
||||||
|
th := SetupWithCluster(t, testCluster)
|
||||||
|
defer th.TearDown()
|
||||||
|
|
||||||
|
hookID := broadcastTest
|
||||||
|
hookArgs := map[string]any{
|
||||||
|
"makes_changes": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
event.GetBroadcast().AddHook(hookID, hookArgs)
|
||||||
|
|
||||||
|
th.Service.Publish(event)
|
||||||
|
|
||||||
|
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
|
||||||
|
assert.Equal(t, []map[string]any{hookArgs}, received.GetBroadcast().BroadcastHookArgs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not preserve type information for args", func(t *testing.T) {
|
||||||
|
// This behaviour isn't ideal, but this test confirms that it hasn't changed
|
||||||
|
testCluster := &testlib.FakeClusterInterface{}
|
||||||
|
|
||||||
|
th := SetupWithCluster(t, testCluster)
|
||||||
|
defer th.TearDown()
|
||||||
|
|
||||||
|
hookID := "test_broadcast_hook_with_args"
|
||||||
|
hookArgs := map[string]any{
|
||||||
|
"user": &model.User{Id: "user1"},
|
||||||
|
"array": []string{"a", "b", "c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||||
|
event.GetBroadcast().AddHook(hookID, hookArgs)
|
||||||
|
|
||||||
|
th.Service.Publish(event)
|
||||||
|
|
||||||
|
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
|
||||||
|
assert.IsType(t, map[string]any{}, received.GetBroadcast().BroadcastHookArgs[0]["user"])
|
||||||
|
assert.IsType(t, []any{}, received.GetBroadcast().BroadcastHookArgs[0]["array"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -270,7 +270,7 @@ func NewServer(options ...Option) (*Server, error) {
|
|||||||
// It is important to initialize the hub only after the global logger is set
|
// It is important to initialize the hub only after the global logger is set
|
||||||
// to avoid race conditions while logging from inside the hub.
|
// to avoid race conditions while logging from inside the hub.
|
||||||
// Step 4: Start platform
|
// Step 4: Start platform
|
||||||
s.platform.Start()
|
s.platform.Start(s.makeBroadcastHooks())
|
||||||
|
|
||||||
// NOTE: There should be no call to App.Srv().Channels() before step 5 is done
|
// NOTE: There should be no call to App.Srv().Channels() before step 5 is done
|
||||||
// otherwise it will throw a panic.
|
// otherwise it will throw a panic.
|
||||||
|
96
server/channels/app/web_broadcast_hooks.go
Normal file
96
server/channels/app/web_broadcast_hooks.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
pUtils "github.com/mattermost/mattermost/server/public/utils"
|
||||||
|
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
broadcastAddMentions = "add_mentions"
|
||||||
|
broadcastAddFollowers = "add_followers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook {
|
||||||
|
return map[string]platform.BroadcastHook{
|
||||||
|
broadcastAddMentions: &addMentionsBroadcastHook{},
|
||||||
|
broadcastAddFollowers: &addFollowersBroadcastHook{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type addMentionsBroadcastHook struct{}
|
||||||
|
|
||||||
|
func (h *addMentionsBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
|
||||||
|
mentions, err := getTypedArg[model.StringArray](args, "mentions")
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Invalid mentions value passed to addMentionsBroadcastHook")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(mentions) > 0 && pUtils.Contains[string](mentions, webConn.UserId) {
|
||||||
|
// Note that the client expects this field to be stringified
|
||||||
|
msg.Add("mentions", model.ArrayToJSON([]string{webConn.UserId}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func useAddMentionsHook(message *model.WebSocketEvent, mentionedUsers model.StringArray) {
|
||||||
|
message.GetBroadcast().AddHook(broadcastAddMentions, map[string]any{
|
||||||
|
"mentions": mentionedUsers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type addFollowersBroadcastHook struct{}
|
||||||
|
|
||||||
|
func (h *addFollowersBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
|
||||||
|
followers, err := getTypedArg[model.StringArray](args, "followers")
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Invalid followers value passed to addFollowersBroadcastHook")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(followers) > 0 && pUtils.Contains[string](followers, webConn.UserId) {
|
||||||
|
// Note that the client expects this field to be stringified
|
||||||
|
msg.Add("followers", model.ArrayToJSON([]string{webConn.UserId}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func useAddFollowersHook(message *model.WebSocketEvent, followers model.StringArray) {
|
||||||
|
message.GetBroadcast().AddHook(broadcastAddFollowers, map[string]any{
|
||||||
|
"followers": followers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTypedArg returns a correctly typed hook argument with the given key, reinterpreting the type using JSON encoding
|
||||||
|
// if necessary. This is needed because broadcast hook args are JSON encoded in a multi-server environment, and any
|
||||||
|
// type information is lost because those types aren't known at decode time.
|
||||||
|
func getTypedArg[T any](args map[string]any, key string) (T, error) {
|
||||||
|
var value T
|
||||||
|
|
||||||
|
untyped, ok := args[key]
|
||||||
|
if !ok {
|
||||||
|
return value, fmt.Errorf("No argument found with key: %s", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value is already correct, just return it
|
||||||
|
if typed, ok := untyped.(T); ok {
|
||||||
|
return typed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal and unmarshal the data with the correct typing information
|
||||||
|
buf, err := json.Marshal(untyped)
|
||||||
|
if err != nil {
|
||||||
|
return value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(buf, &value)
|
||||||
|
return value, err
|
||||||
|
}
|
114
server/channels/app/web_broadcast_hooks_test.go
Normal file
114
server/channels/app/web_broadcast_hooks_test.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||||
|
// See LICENSE.txt for license information.
|
||||||
|
|
||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
|
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddMentionsHook_Process(t *testing.T) {
|
||||||
|
hook := &addMentionsBroadcastHook{}
|
||||||
|
|
||||||
|
userID := model.NewId()
|
||||||
|
otherUserID := model.NewId()
|
||||||
|
|
||||||
|
webConn := &platform.WebConn{
|
||||||
|
UserId: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("should add a mentions entry for the current user", func(t *testing.T) {
|
||||||
|
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||||
|
|
||||||
|
require.Nil(t, msg.Event().GetData()["mentions"])
|
||||||
|
|
||||||
|
hook.Process(msg, webConn, map[string]any{
|
||||||
|
"mentions": model.StringArray{userID},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
|
||||||
|
assert.Nil(t, msg.Event().GetData()["followers"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not add a mentions entry for another user", func(t *testing.T) {
|
||||||
|
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||||
|
|
||||||
|
require.Nil(t, msg.Event().GetData()["mentions"])
|
||||||
|
|
||||||
|
hook.Process(msg, webConn, map[string]any{
|
||||||
|
"mentions": model.StringArray{otherUserID},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Nil(t, msg.Event().GetData()["mentions"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddFollowersHook_Process(t *testing.T) {
|
||||||
|
hook := &addFollowersBroadcastHook{}
|
||||||
|
|
||||||
|
userID := model.NewId()
|
||||||
|
otherUserID := model.NewId()
|
||||||
|
|
||||||
|
webConn := &platform.WebConn{
|
||||||
|
UserId: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("should add a followers entry for the current user", func(t *testing.T) {
|
||||||
|
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||||
|
|
||||||
|
require.Nil(t, msg.Event().GetData()["followers"])
|
||||||
|
|
||||||
|
hook.Process(msg, webConn, map[string]any{
|
||||||
|
"followers": model.StringArray{userID},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should not add a followers entry for another user", func(t *testing.T) {
|
||||||
|
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||||
|
|
||||||
|
require.Nil(t, msg.Event().GetData()["followers"])
|
||||||
|
|
||||||
|
hook.Process(msg, webConn, map[string]any{
|
||||||
|
"followers": model.StringArray{otherUserID},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Nil(t, msg.Event().GetData()["followers"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMentionsAndAddFollowersHooks(t *testing.T) {
|
||||||
|
addMentionsHook := &addMentionsBroadcastHook{}
|
||||||
|
addFollowersHook := &addFollowersBroadcastHook{}
|
||||||
|
|
||||||
|
userID := model.NewId()
|
||||||
|
|
||||||
|
webConn := &platform.WebConn{
|
||||||
|
UserId: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||||
|
|
||||||
|
originalData := msg.Event().GetData()
|
||||||
|
|
||||||
|
require.Nil(t, originalData["mentions"])
|
||||||
|
require.Nil(t, originalData["followers"])
|
||||||
|
|
||||||
|
addMentionsHook.Process(msg, webConn, map[string]any{
|
||||||
|
"mentions": model.StringArray{userID},
|
||||||
|
})
|
||||||
|
addFollowersHook.Process(msg, webConn, map[string]any{
|
||||||
|
"followers": model.StringArray{userID},
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should be able to add both mentions and followers to a single event", func(t *testing.T) {
|
||||||
|
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
|
||||||
|
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
|
||||||
|
})
|
||||||
|
}
|
@ -104,6 +104,17 @@ type WebsocketBroadcast struct {
|
|||||||
// ReliableClusterSend indicates whether or not the message should
|
// ReliableClusterSend indicates whether or not the message should
|
||||||
// be sent through the cluster using the reliable, TCP backed channel.
|
// be sent through the cluster using the reliable, TCP backed channel.
|
||||||
ReliableClusterSend bool `json:"-"`
|
ReliableClusterSend bool `json:"-"`
|
||||||
|
|
||||||
|
// BroadcastHooks is a slice of hooks IDs used to process events before sending them on individual connections. The
|
||||||
|
// IDs should be understood by the WebSocket code.
|
||||||
|
//
|
||||||
|
// This field should never be sent to the client.
|
||||||
|
BroadcastHooks []string `json:"broadcast_hooks,omitempty"`
|
||||||
|
// BroadcastHookArgs is a slice of named arguments for each hook invocation. The index of each entry corresponds to
|
||||||
|
// the index of a hook ID in BroadcastHooks
|
||||||
|
//
|
||||||
|
// This field should never be sent to the client.
|
||||||
|
BroadcastHookArgs []map[string]any `json:"broadcast_hook_args,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
|
func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
|
||||||
@ -124,10 +135,17 @@ func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
|
|||||||
c.OmitConnectionId = wb.OmitConnectionId
|
c.OmitConnectionId = wb.OmitConnectionId
|
||||||
c.ContainsSanitizedData = wb.ContainsSanitizedData
|
c.ContainsSanitizedData = wb.ContainsSanitizedData
|
||||||
c.ContainsSensitiveData = wb.ContainsSensitiveData
|
c.ContainsSensitiveData = wb.ContainsSensitiveData
|
||||||
|
c.BroadcastHooks = wb.BroadcastHooks
|
||||||
|
c.BroadcastHookArgs = wb.BroadcastHookArgs
|
||||||
|
|
||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (wb *WebsocketBroadcast) AddHook(hookID string, hookArgs map[string]any) {
|
||||||
|
wb.BroadcastHooks = append(wb.BroadcastHooks, hookID)
|
||||||
|
wb.BroadcastHookArgs = append(wb.BroadcastHookArgs, hookArgs)
|
||||||
|
}
|
||||||
|
|
||||||
type precomputedWebSocketEventJSON struct {
|
type precomputedWebSocketEventJSON struct {
|
||||||
Event json.RawMessage
|
Event json.RawMessage
|
||||||
Data json.RawMessage
|
Data json.RawMessage
|
||||||
@ -190,6 +208,32 @@ func (ev *WebSocketEvent) PrecomputeJSON() *WebSocketEvent {
|
|||||||
return evCopy
|
return evCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ev *WebSocketEvent) RemovePrecomputedJSON() *WebSocketEvent {
|
||||||
|
evCopy := ev.DeepCopy()
|
||||||
|
evCopy.precomputedJSON = nil
|
||||||
|
return evCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithoutBroadcastHooks gets the broadcast hook information from a WebSocketEvent and returns the event without that.
|
||||||
|
// If the event has broadcast hooks, a copy of the event is returned. Otherwise, the original event is returned. This
|
||||||
|
// is intended to be called before the event is sent to the client.
|
||||||
|
func (ev *WebSocketEvent) WithoutBroadcastHooks() (*WebSocketEvent, []string, []map[string]any) {
|
||||||
|
hooks := ev.broadcast.BroadcastHooks
|
||||||
|
hookArgs := ev.broadcast.BroadcastHookArgs
|
||||||
|
|
||||||
|
if len(hooks) == 0 && len(hookArgs) == 0 {
|
||||||
|
return ev, hooks, hookArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
evCopy := ev.Copy()
|
||||||
|
evCopy.broadcast = ev.broadcast.copy()
|
||||||
|
|
||||||
|
evCopy.broadcast.BroadcastHooks = nil
|
||||||
|
evCopy.broadcast.BroadcastHookArgs = nil
|
||||||
|
|
||||||
|
return evCopy, hooks, hookArgs
|
||||||
|
}
|
||||||
|
|
||||||
func (ev *WebSocketEvent) Add(key string, value any) {
|
func (ev *WebSocketEvent) Add(key string, value any) {
|
||||||
ev.data[key] = value
|
ev.data[key] = value
|
||||||
}
|
}
|
||||||
@ -219,17 +263,9 @@ func (ev *WebSocketEvent) Copy() *WebSocketEvent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
|
func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
|
||||||
var dataCopy map[string]any
|
|
||||||
if ev.data != nil {
|
|
||||||
dataCopy = make(map[string]any, len(ev.data))
|
|
||||||
for k, v := range ev.data {
|
|
||||||
dataCopy[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
evCopy := &WebSocketEvent{
|
evCopy := &WebSocketEvent{
|
||||||
event: ev.event,
|
event: ev.event,
|
||||||
data: dataCopy,
|
data: copyMap(ev.data),
|
||||||
broadcast: ev.broadcast.copy(),
|
broadcast: ev.broadcast.copy(),
|
||||||
sequence: ev.sequence,
|
sequence: ev.sequence,
|
||||||
precomputedJSON: ev.precomputedJSON.copy(),
|
precomputedJSON: ev.precomputedJSON.copy(),
|
||||||
@ -237,6 +273,14 @@ func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
|
|||||||
return evCopy
|
return evCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyMap[K comparable, V any](m map[K]V) map[K]V {
|
||||||
|
dataCopy := make(map[K]V, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
dataCopy[k] = v
|
||||||
|
}
|
||||||
|
return dataCopy
|
||||||
|
}
|
||||||
|
|
||||||
func (ev *WebSocketEvent) GetData() map[string]any {
|
func (ev *WebSocketEvent) GetData() map[string]any {
|
||||||
return ev.data
|
return ev.data
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user