mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-54238 Add WebSocket broadcast hook and don't broadcast when other users were mentioned (#24641)
* MM-54238 Initial implementation * MM-54238 Move websocket hook into app package * MM-54238 Add tests for mentions in posted websocket messages * Fix styling * Fix other styling * Idiomatic ID naming for new code * Fix more styles * Separate hooks to add mentions and followers * Improved error handling for invalid types in hooks * Rename HasChanges to ShouldProcess * Pass broadcast hooks through hubStart * Add test helper for asserting json unmarshaling * Fix missing arguments in tests * Ensure broadcast hooks are sent across the cluster and not to users * Ensure tests actually cover following a post * Fix code broken by merge * Go vet again... * Deep copy event before processing it with hooks * Replace RemoveBroadcastHooks with WithoutBroadcastHooks * Address feedback * Add helper to fix type information for hook args * Wrap WebSocketEvent and simplify BroadcastHook * Address feedback * Address feedback
This commit is contained in:
parent
5e62ba8ccc
commit
ef66f7beab
@ -517,12 +517,12 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea
|
||||
}
|
||||
}
|
||||
|
||||
if len(mentionedUsersList) != 0 {
|
||||
message.Add("mentions", model.ArrayToJSON(mentionedUsersList))
|
||||
if len(mentionedUsersList) > 0 {
|
||||
useAddMentionsHook(message, mentionedUsersList)
|
||||
}
|
||||
|
||||
if len(notificationsForCRT.Desktop) != 0 {
|
||||
message.Add("followers", model.ArrayToJSON(notificationsForCRT.Desktop))
|
||||
if len(notificationsForCRT.Desktop) > 0 {
|
||||
useAddFollowersHook(message, notificationsForCRT.Desktop)
|
||||
}
|
||||
|
||||
published, err := a.publishWebsocketEventForPermalinkPost(c, post, message)
|
||||
|
@ -4,10 +4,14 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -15,6 +19,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
||||
pUtils "github.com/mattermost/mattermost/server/public/utils"
|
||||
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||
)
|
||||
|
||||
@ -212,6 +217,314 @@ func TestSendNotifications(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendNotifications_MentionsFollowers(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
th.AddUserToChannel(th.BasicUser2, th.BasicChannel)
|
||||
|
||||
sender := th.CreateUser()
|
||||
|
||||
th.LinkUserToTeam(sender, th.BasicTeam)
|
||||
member := th.AddUserToChannel(sender, th.BasicChannel)
|
||||
|
||||
t.Run("should inform each user if they were mentioned by a post", func(t *testing.T) {
|
||||
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||
defer closeWS1()
|
||||
|
||||
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||
defer closeWS2()
|
||||
|
||||
// First post mentioning the whole channel
|
||||
post := &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: "@channel",
|
||||
}
|
||||
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
received1 := <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||
|
||||
received2 := <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||
|
||||
// Second post mentioning both users individually
|
||||
post = &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: fmt.Sprintf("@%s @%s", th.BasicUser.Username, th.BasicUser2.Username),
|
||||
}
|
||||
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
received1 = <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||
|
||||
received2 = <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||
|
||||
// Third post mentioning a single user
|
||||
post = &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: "@" + th.BasicUser.Username,
|
||||
}
|
||||
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
received1 = <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||
|
||||
received2 = <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assert.Nil(t, received2.GetData()["mentions"])
|
||||
})
|
||||
|
||||
t.Run("should inform each user in a group if they were mentioned by a post", func(t *testing.T) {
|
||||
// Make the sender a channel_admin because that's needed for group mentions
|
||||
originalRoles := member.Roles
|
||||
member.Roles = "channel_user channel_admin"
|
||||
_, appErr := th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, member.Roles)
|
||||
require.Nil(t, appErr)
|
||||
|
||||
defer func() {
|
||||
th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, originalRoles)
|
||||
}()
|
||||
|
||||
th.App.Srv().SetLicense(getLicWithSkuShortName(model.LicenseShortSkuEnterprise))
|
||||
|
||||
// Make a group and add users
|
||||
group := th.CreateGroup()
|
||||
group.AllowReference = true
|
||||
group, updateErr := th.App.UpdateGroup(group)
|
||||
require.Nil(t, updateErr)
|
||||
|
||||
_, upsertErr := th.App.UpsertGroupMember(group.Id, th.BasicUser.Id)
|
||||
require.Nil(t, upsertErr)
|
||||
_, upsertErr = th.App.UpsertGroupMember(group.Id, th.BasicUser2.Id)
|
||||
require.Nil(t, upsertErr)
|
||||
|
||||
// Set up the websockets
|
||||
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||
defer closeWS1()
|
||||
|
||||
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||
defer closeWS2()
|
||||
|
||||
// Confirm permissions for group mentions are correct
|
||||
post := &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: "@" + *group.Name,
|
||||
}
|
||||
require.True(t, th.App.allowGroupMentions(th.Context, post))
|
||||
|
||||
// Test sending notifications
|
||||
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
received1 := <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
|
||||
|
||||
received2 := <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
|
||||
})
|
||||
|
||||
t.Run("should inform each user if they are following a thread that was posted in", func(t *testing.T) {
|
||||
t.Log("BasicUser ", th.BasicUser.Id)
|
||||
t.Log("sender ", sender.Id)
|
||||
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||
defer closeWS1()
|
||||
|
||||
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||
defer closeWS2()
|
||||
|
||||
// Reply to a post made by BasicUser
|
||||
post := &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
RootId: th.BasicPost.Id,
|
||||
Message: "This is a test",
|
||||
}
|
||||
|
||||
// Use CreatePost instead of SendNotifications here since we need that to set up some threads state
|
||||
_, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, false, false)
|
||||
require.Nil(t, appErr)
|
||||
|
||||
received1 := <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["followers"])
|
||||
|
||||
received2 := <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assert.Nil(t, received2.GetData()["followers"])
|
||||
})
|
||||
|
||||
t.Run("should not include broadcast hook information in messages sent to users", func(t *testing.T) {
|
||||
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
|
||||
defer closeWS1()
|
||||
|
||||
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
|
||||
defer closeWS2()
|
||||
|
||||
// For a post mentioning only one user, nobody in the channel should receive information about the broadcast hooks
|
||||
post := &model.Post{
|
||||
UserId: sender.Id,
|
||||
ChannelId: th.BasicChannel.Id,
|
||||
Message: fmt.Sprintf("@%s", th.BasicUser.Username),
|
||||
}
|
||||
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
received1 := <-messages1
|
||||
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
|
||||
assert.Nil(t, received1.GetBroadcast().BroadcastHooks)
|
||||
assert.Nil(t, received1.GetBroadcast().BroadcastHookArgs)
|
||||
|
||||
received2 := <-messages2
|
||||
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
|
||||
assert.Nil(t, received2.GetBroadcast().BroadcastHooks)
|
||||
assert.Nil(t, received2.GetBroadcast().BroadcastHookArgs)
|
||||
})
|
||||
}
|
||||
|
||||
func assertUnmarshalsTo(t *testing.T, expected any, actual any) {
|
||||
t.Helper()
|
||||
|
||||
val, err := json.Marshal(expected)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.JSONEq(t, string(val), actual.(string))
|
||||
}
|
||||
|
||||
func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectionID string) (chan *model.WebSocketEvent, func()) {
|
||||
var session *model.Session
|
||||
var server *httptest.Server
|
||||
var webConn *platform.WebConn
|
||||
|
||||
closeWS := func() {
|
||||
if webConn != nil {
|
||||
webConn.Close()
|
||||
}
|
||||
if server != nil {
|
||||
server.Close()
|
||||
}
|
||||
if session != nil {
|
||||
appErr := th.App.RevokeSession(th.Context, session)
|
||||
require.Nil(t, appErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a session for the user's connection
|
||||
var appErr *model.AppError
|
||||
session, appErr = th.App.CreateSession(th.Context, &model.Session{
|
||||
UserId: userID,
|
||||
})
|
||||
require.Nil(t, appErr)
|
||||
|
||||
// Create a channel and an HTTP server to handle incoming WS events
|
||||
messages := make(chan *model.WebSocketEvent)
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := &websocket.Upgrader{}
|
||||
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Log("Received error when upgrading WebSocket connection", err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
_, reader, err := c.NextReader()
|
||||
if err != nil {
|
||||
t.Log("Received error when reading from WebSocket connection", err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := model.WebSocketEventFromJSON(reader)
|
||||
if err != nil {
|
||||
t.Log("Received error when decoding from WebSocket connection", err)
|
||||
break
|
||||
}
|
||||
|
||||
messages <- msg
|
||||
}
|
||||
}))
|
||||
|
||||
// Connect the WebSocket
|
||||
d := websocket.Dialer{}
|
||||
ws, _, err := d.Dial("ws://"+server.Listener.Addr().String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register the WebSocket with the server as a WebConn
|
||||
if connectionID == "" {
|
||||
connectionID = model.NewId()
|
||||
}
|
||||
webConn = th.App.Srv().Platform().NewWebConn(&platform.WebConnConfig{
|
||||
WebSocket: ws,
|
||||
Session: *session,
|
||||
TFunc: i18n.IdentityTfunc(),
|
||||
Locale: "en",
|
||||
ConnectionID: connectionID,
|
||||
}, th.App, th.App.Channels())
|
||||
th.App.Srv().Platform().HubRegister(webConn)
|
||||
|
||||
// Start reading from it
|
||||
go webConn.Pump()
|
||||
|
||||
// Read the events which always occur at the start of a WebSocket connection
|
||||
received := <-messages
|
||||
assert.Equal(t, model.WebsocketEventHello, received.EventType())
|
||||
|
||||
received = <-messages
|
||||
assert.Equal(t, model.WebsocketEventStatusChange, received.EventType())
|
||||
|
||||
return messages, closeWS
|
||||
}
|
||||
|
||||
func TestConnectFakeWebSocket(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
teamID := th.BasicTeam.Id
|
||||
userID := th.BasicUser.Id
|
||||
|
||||
messages, closeWS := connectFakeWebSocket(t, th, userID, "")
|
||||
defer closeWS()
|
||||
|
||||
msg := model.NewWebSocketEvent(model.WebsocketEventPosted, teamID, "", "", nil, "")
|
||||
th.App.Publish(msg)
|
||||
|
||||
msg = model.NewWebSocketEvent("test_event_with_data", "", "", userID, nil, "")
|
||||
msg.Add("key1", "value1")
|
||||
msg.Add("key2", 2)
|
||||
msg.Add("key3", []string{"three", "trois"})
|
||||
th.App.Publish(msg)
|
||||
|
||||
received := <-messages
|
||||
require.Equal(t, model.WebsocketEventPosted, received.EventType())
|
||||
assert.Equal(t, teamID, received.GetBroadcast().TeamId)
|
||||
|
||||
received = <-messages
|
||||
require.Equal(t, "test_event_with_data", received.EventType())
|
||||
assert.Equal(t, userID, received.GetBroadcast().UserId)
|
||||
// These type changes are annoying but unavoidable because event data is untyped
|
||||
assert.Equal(t, map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": float64(2),
|
||||
"key3": []any{"three", "trois"},
|
||||
}, received.GetData())
|
||||
}
|
||||
|
||||
func TestSendNotificationsWithManyUsers(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
@ -187,7 +187,7 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
|
||||
th.Service.SetLicense(nil)
|
||||
}
|
||||
|
||||
err = th.Service.Start()
|
||||
err = th.Service.Start(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -344,8 +344,8 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) {
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func (ps *PlatformService) Start() error {
|
||||
ps.hubStart()
|
||||
func (ps *PlatformService) Start(broadcastHooks map[string]BroadcastHook) error {
|
||||
ps.hubStart(broadcastHooks)
|
||||
|
||||
ps.configListenerId = ps.AddConfigListener(func(_, _ *model.Config) {
|
||||
ps.regenerateClientConfig()
|
||||
|
87
server/channels/app/platform/web_broadcast_hook.go
Normal file
87
server/channels/app/platform/web_broadcast_hook.go
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package platform
|
||||
|
||||
import (
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
)
|
||||
|
||||
type BroadcastHook interface {
|
||||
// Process takes a WebSocket event and modifies it in some way. It is passed a HookedWebSocketEvent which allows
|
||||
// safe modification of the event.
|
||||
Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error
|
||||
}
|
||||
|
||||
func (h *Hub) runBroadcastHooks(msg *model.WebSocketEvent, webConn *WebConn, hookIDs []string, hookArgs []map[string]any) *model.WebSocketEvent {
|
||||
if len(hookIDs) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
hookedEvent := MakeHookedWebSocketEvent(msg)
|
||||
|
||||
for i, hookID := range hookIDs {
|
||||
hook := h.broadcastHooks[hookID]
|
||||
args := hookArgs[i]
|
||||
if hook == nil {
|
||||
mlog.Warn("runBroadcastHooks: Unable to find broadcast hook", mlog.String("hook_id", hookID))
|
||||
continue
|
||||
}
|
||||
|
||||
hook.Process(hookedEvent, webConn, args)
|
||||
}
|
||||
|
||||
return hookedEvent.Event()
|
||||
}
|
||||
|
||||
// HookedWebSocketEvent is a wrapper for model.WebSocketEvent that is intended to provide a similar interface, except
|
||||
// it ensures the original WebSocket event is not modified.
|
||||
type HookedWebSocketEvent struct {
|
||||
original *model.WebSocketEvent
|
||||
copy *model.WebSocketEvent
|
||||
}
|
||||
|
||||
func MakeHookedWebSocketEvent(event *model.WebSocketEvent) *HookedWebSocketEvent {
|
||||
return &HookedWebSocketEvent{
|
||||
original: event,
|
||||
}
|
||||
}
|
||||
|
||||
func (he *HookedWebSocketEvent) Add(key string, value any) {
|
||||
he.copyIfNecessary()
|
||||
|
||||
he.copy.Add(key, value)
|
||||
}
|
||||
|
||||
func (he *HookedWebSocketEvent) EventType() string {
|
||||
if he.copy == nil {
|
||||
return he.original.EventType()
|
||||
}
|
||||
|
||||
return he.copy.EventType()
|
||||
}
|
||||
|
||||
// Get returns a value from the WebSocket event data. You should never mutate a value returned by this method.
|
||||
func (he *HookedWebSocketEvent) Get(key string) any {
|
||||
if he.copy == nil {
|
||||
return he.original.GetData()[key]
|
||||
}
|
||||
|
||||
return he.copy.GetData()[key]
|
||||
}
|
||||
|
||||
// copyIfNecessary should be called by any mutative method to ensure that the copy is instantiated.
|
||||
func (he *HookedWebSocketEvent) copyIfNecessary() {
|
||||
if he.copy == nil {
|
||||
he.copy = he.original.RemovePrecomputedJSON()
|
||||
}
|
||||
}
|
||||
|
||||
func (he *HookedWebSocketEvent) Event() *model.WebSocketEvent {
|
||||
if he.copy == nil {
|
||||
return he.original
|
||||
}
|
||||
|
||||
return he.copy
|
||||
}
|
206
server/channels/app/platform/web_broadcast_hook_test.go
Normal file
206
server/channels/app/platform/web_broadcast_hook_test.go
Normal file
@ -0,0 +1,206 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package platform
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const broadcastTest = "test_broadcast_hook"
|
||||
|
||||
type testBroadcastHook struct{}
|
||||
|
||||
func (h *testBroadcastHook) Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error {
|
||||
if args["makes_changes"].(bool) {
|
||||
changesMade, _ := msg.Get("changes_made").(int)
|
||||
msg.Add("changes_made", changesMade+1)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRunBroadcastHooks(t *testing.T) {
|
||||
hub := &Hub{
|
||||
broadcastHooks: map[string]BroadcastHook{
|
||||
broadcastTest: &testBroadcastHook{},
|
||||
},
|
||||
}
|
||||
webConn := &WebConn{}
|
||||
|
||||
t.Run("should not allocate a new object when no hooks are passed", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, nil, nil)
|
||||
|
||||
assert.Same(t, event, result)
|
||||
})
|
||||
|
||||
t.Run("should not allocate a new object when a hook is not making changes", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
hookIDs := []string{
|
||||
broadcastTest,
|
||||
}
|
||||
hookArgs := []map[string]any{
|
||||
{
|
||||
"makes_changes": false,
|
||||
},
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
assert.Same(t, event, result)
|
||||
})
|
||||
|
||||
t.Run("should allocate a new object and remove when a hook makes changes", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
hookIDs := []string{
|
||||
broadcastTest,
|
||||
}
|
||||
hookArgs := []map[string]any{
|
||||
{
|
||||
"makes_changes": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
assert.NotSame(t, event, result)
|
||||
assert.NotSame(t, event.GetData(), result.GetData())
|
||||
assert.Equal(t, map[string]any{}, event.GetData())
|
||||
assert.Equal(t, result.GetData(), map[string]any{
|
||||
"changes_made": 1,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("should not allocate a new object when multiple hooks are not making changes", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
hookIDs := []string{
|
||||
broadcastTest,
|
||||
broadcastTest,
|
||||
broadcastTest,
|
||||
}
|
||||
hookArgs := []map[string]any{
|
||||
{
|
||||
"makes_changes": false,
|
||||
},
|
||||
{
|
||||
"makes_changes": false,
|
||||
},
|
||||
{
|
||||
"makes_changes": false,
|
||||
},
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
assert.Same(t, event, result)
|
||||
})
|
||||
|
||||
t.Run("should be able to make changes from only one of make hooks", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
var hookIDs []string
|
||||
var hookArgs []map[string]any
|
||||
for i := 0; i < 10; i++ {
|
||||
hookIDs = append(hookIDs, broadcastTest)
|
||||
hookArgs = append(hookArgs, map[string]any{
|
||||
"makes_changes": i == 6,
|
||||
})
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
assert.NotSame(t, event, result)
|
||||
assert.NotSame(t, event.GetData(), result.GetData())
|
||||
assert.Equal(t, event.GetData(), map[string]any{})
|
||||
assert.Equal(t, result.GetData(), map[string]any{
|
||||
"changes_made": 1,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("should be able to make changes from multiple hooks", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
|
||||
var hookIDs []string
|
||||
var hookArgs []map[string]any
|
||||
for i := 0; i < 10; i++ {
|
||||
hookIDs = append(hookIDs, broadcastTest)
|
||||
hookArgs = append(hookArgs, map[string]any{
|
||||
"makes_changes": true,
|
||||
})
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
assert.NotSame(t, event, result)
|
||||
assert.NotSame(t, event.GetData(), result.GetData())
|
||||
assert.Equal(t, event.GetData(), map[string]any{})
|
||||
assert.Equal(t, result.GetData(), map[string]any{
|
||||
"changes_made": 10,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("should not remove precomputed JSON when a hook doesn't make changes", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
event = event.PrecomputeJSON()
|
||||
|
||||
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
|
||||
originalJSON, _ := event.ToJSON()
|
||||
event.Add("data", 1234)
|
||||
eventJSON, _ := event.ToJSON()
|
||||
require.Equal(t, string(originalJSON), string(eventJSON))
|
||||
|
||||
hookIDs := []string{
|
||||
broadcastTest,
|
||||
}
|
||||
hookArgs := []map[string]any{
|
||||
{
|
||||
"makes_changes": false,
|
||||
},
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
eventJSON, _ = event.ToJSON()
|
||||
assert.Equal(t, string(originalJSON), string(eventJSON))
|
||||
|
||||
resultJSON, _ := result.ToJSON()
|
||||
assert.Equal(t, originalJSON, resultJSON)
|
||||
})
|
||||
|
||||
t.Run("should remove precomputed JSON when a hook makes changes", func(t *testing.T) {
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
event = event.PrecomputeJSON()
|
||||
|
||||
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
|
||||
originalJSON, _ := event.ToJSON()
|
||||
event.Add("data", 1234)
|
||||
eventJSON, _ := event.ToJSON()
|
||||
require.Equal(t, originalJSON, eventJSON)
|
||||
|
||||
hookIDs := []string{
|
||||
broadcastTest,
|
||||
}
|
||||
hookArgs := []map[string]any{
|
||||
{
|
||||
"makes_changes": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
|
||||
|
||||
eventJSON, _ = event.ToJSON()
|
||||
assert.Equal(t, string(originalJSON), string(eventJSON))
|
||||
|
||||
resultJSON, _ := result.ToJSON()
|
||||
assert.NotEqual(t, originalJSON, resultJSON)
|
||||
})
|
||||
}
|
@ -70,6 +70,7 @@ type Hub struct {
|
||||
explicitStop bool
|
||||
checkRegistered chan *webConnSessionMessage
|
||||
checkConn chan *webConnCheckMessage
|
||||
broadcastHooks map[string]BroadcastHook
|
||||
}
|
||||
|
||||
// newWebHub creates a new Hub.
|
||||
@ -90,7 +91,7 @@ func newWebHub(ps *PlatformService) *Hub {
|
||||
}
|
||||
|
||||
// hubStart starts all the hubs.
|
||||
func (ps *PlatformService) hubStart() {
|
||||
func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) {
|
||||
// Total number of hubs is twice the number of CPUs.
|
||||
numberOfHubs := runtime.NumCPU() * 2
|
||||
ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
|
||||
@ -100,6 +101,7 @@ func (ps *PlatformService) hubStart() {
|
||||
for i := 0; i < numberOfHubs; i++ {
|
||||
hubs[i] = newWebHub(ps)
|
||||
hubs[i].connectionIndex = i
|
||||
hubs[i].broadcastHooks = broadcastHooks
|
||||
hubs[i].Start()
|
||||
}
|
||||
// Assigning to the hubs slice without any mutex is fine because it is only assigned once
|
||||
@ -492,14 +494,19 @@ func (h *Hub) Start() {
|
||||
if metrics := h.platform.metricsIFace; metrics != nil {
|
||||
metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
|
||||
}
|
||||
|
||||
// Remove the broadcast hook information before precomputing the JSON so that those aren't included in it
|
||||
msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks()
|
||||
|
||||
msg = msg.PrecomputeJSON()
|
||||
|
||||
broadcast := func(webConn *WebConn) {
|
||||
if !connIndex.Has(webConn) {
|
||||
return
|
||||
}
|
||||
if webConn.ShouldSendEvent(msg) {
|
||||
select {
|
||||
case webConn.send <- msg:
|
||||
case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
|
||||
default:
|
||||
// Don't log the warning if it's an inactive connection.
|
||||
if webConn.active.Load() {
|
||||
|
@ -4,6 +4,7 @@
|
||||
package platform
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -69,7 +70,7 @@ func TestHubStopWithMultipleConnections(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
th.Service.Start()
|
||||
th.Service.Start(nil)
|
||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
@ -93,7 +94,7 @@ func TestHubStopRaceCondition(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
th.Service.Start()
|
||||
th.Service.Start(nil)
|
||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
defer wc1.Close()
|
||||
|
||||
@ -476,7 +477,7 @@ func TestHubIsRegistered(t *testing.T) {
|
||||
s := httptest.NewServer(dummyWebsocketHandler(t))
|
||||
defer s.Close()
|
||||
|
||||
th.Service.Start()
|
||||
th.Service.Start(nil)
|
||||
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
|
||||
@ -583,7 +584,7 @@ func BenchmarkGetHubForUserId(b *testing.B) {
|
||||
th := Setup(b).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
th.Service.Start()
|
||||
th.Service.Start(nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -618,3 +619,54 @@ func TestClusterBroadcast(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clusterEvent.Broadcast, broadcast)
|
||||
}
|
||||
|
||||
func TestClusterBroadcastHooks(t *testing.T) {
|
||||
t.Run("should send broadcast hook information across cluster", func(t *testing.T) {
|
||||
testCluster := &testlib.FakeClusterInterface{}
|
||||
|
||||
th := SetupWithCluster(t, testCluster)
|
||||
defer th.TearDown()
|
||||
|
||||
hookID := broadcastTest
|
||||
hookArgs := map[string]any{
|
||||
"makes_changes": true,
|
||||
}
|
||||
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
event.GetBroadcast().AddHook(hookID, hookArgs)
|
||||
|
||||
th.Service.Publish(event)
|
||||
|
||||
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
|
||||
assert.Equal(t, []map[string]any{hookArgs}, received.GetBroadcast().BroadcastHookArgs)
|
||||
})
|
||||
|
||||
t.Run("should not preserve type information for args", func(t *testing.T) {
|
||||
// This behaviour isn't ideal, but this test confirms that it hasn't changed
|
||||
testCluster := &testlib.FakeClusterInterface{}
|
||||
|
||||
th := SetupWithCluster(t, testCluster)
|
||||
defer th.TearDown()
|
||||
|
||||
hookID := "test_broadcast_hook_with_args"
|
||||
hookArgs := map[string]any{
|
||||
"user": &model.User{Id: "user1"},
|
||||
"array": []string{"a", "b", "c"},
|
||||
}
|
||||
|
||||
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
|
||||
event.GetBroadcast().AddHook(hookID, hookArgs)
|
||||
|
||||
th.Service.Publish(event)
|
||||
|
||||
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
|
||||
assert.IsType(t, map[string]any{}, received.GetBroadcast().BroadcastHookArgs[0]["user"])
|
||||
assert.IsType(t, []any{}, received.GetBroadcast().BroadcastHookArgs[0]["array"])
|
||||
})
|
||||
}
|
||||
|
@ -270,7 +270,7 @@ func NewServer(options ...Option) (*Server, error) {
|
||||
// It is important to initialize the hub only after the global logger is set
|
||||
// to avoid race conditions while logging from inside the hub.
|
||||
// Step 4: Start platform
|
||||
s.platform.Start()
|
||||
s.platform.Start(s.makeBroadcastHooks())
|
||||
|
||||
// NOTE: There should be no call to App.Srv().Channels() before step 5 is done
|
||||
// otherwise it will throw a panic.
|
||||
|
96
server/channels/app/web_broadcast_hooks.go
Normal file
96
server/channels/app/web_broadcast_hooks.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
pUtils "github.com/mattermost/mattermost/server/public/utils"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
broadcastAddMentions = "add_mentions"
|
||||
broadcastAddFollowers = "add_followers"
|
||||
)
|
||||
|
||||
func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook {
|
||||
return map[string]platform.BroadcastHook{
|
||||
broadcastAddMentions: &addMentionsBroadcastHook{},
|
||||
broadcastAddFollowers: &addFollowersBroadcastHook{},
|
||||
}
|
||||
}
|
||||
|
||||
type addMentionsBroadcastHook struct{}
|
||||
|
||||
func (h *addMentionsBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
|
||||
mentions, err := getTypedArg[model.StringArray](args, "mentions")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Invalid mentions value passed to addMentionsBroadcastHook")
|
||||
}
|
||||
|
||||
if len(mentions) > 0 && pUtils.Contains[string](mentions, webConn.UserId) {
|
||||
// Note that the client expects this field to be stringified
|
||||
msg.Add("mentions", model.ArrayToJSON([]string{webConn.UserId}))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func useAddMentionsHook(message *model.WebSocketEvent, mentionedUsers model.StringArray) {
|
||||
message.GetBroadcast().AddHook(broadcastAddMentions, map[string]any{
|
||||
"mentions": mentionedUsers,
|
||||
})
|
||||
}
|
||||
|
||||
type addFollowersBroadcastHook struct{}
|
||||
|
||||
func (h *addFollowersBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
|
||||
followers, err := getTypedArg[model.StringArray](args, "followers")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Invalid followers value passed to addFollowersBroadcastHook")
|
||||
}
|
||||
|
||||
if len(followers) > 0 && pUtils.Contains[string](followers, webConn.UserId) {
|
||||
// Note that the client expects this field to be stringified
|
||||
msg.Add("followers", model.ArrayToJSON([]string{webConn.UserId}))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func useAddFollowersHook(message *model.WebSocketEvent, followers model.StringArray) {
|
||||
message.GetBroadcast().AddHook(broadcastAddFollowers, map[string]any{
|
||||
"followers": followers,
|
||||
})
|
||||
}
|
||||
|
||||
// getTypedArg returns a correctly typed hook argument with the given key, reinterpreting the type using JSON encoding
|
||||
// if necessary. This is needed because broadcast hook args are JSON encoded in a multi-server environment, and any
|
||||
// type information is lost because those types aren't known at decode time.
|
||||
func getTypedArg[T any](args map[string]any, key string) (T, error) {
|
||||
var value T
|
||||
|
||||
untyped, ok := args[key]
|
||||
if !ok {
|
||||
return value, fmt.Errorf("No argument found with key: %s", key)
|
||||
}
|
||||
|
||||
// If the value is already correct, just return it
|
||||
if typed, ok := untyped.(T); ok {
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
// Marshal and unmarshal the data with the correct typing information
|
||||
buf, err := json.Marshal(untyped)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(buf, &value)
|
||||
return value, err
|
||||
}
|
114
server/channels/app/web_broadcast_hooks_test.go
Normal file
114
server/channels/app/web_broadcast_hooks_test.go
Normal file
@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
||||
// See LICENSE.txt for license information.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddMentionsHook_Process(t *testing.T) {
|
||||
hook := &addMentionsBroadcastHook{}
|
||||
|
||||
userID := model.NewId()
|
||||
otherUserID := model.NewId()
|
||||
|
||||
webConn := &platform.WebConn{
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
t.Run("should add a mentions entry for the current user", func(t *testing.T) {
|
||||
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||
|
||||
require.Nil(t, msg.Event().GetData()["mentions"])
|
||||
|
||||
hook.Process(msg, webConn, map[string]any{
|
||||
"mentions": model.StringArray{userID},
|
||||
})
|
||||
|
||||
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
|
||||
assert.Nil(t, msg.Event().GetData()["followers"])
|
||||
})
|
||||
|
||||
t.Run("should not add a mentions entry for another user", func(t *testing.T) {
|
||||
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||
|
||||
require.Nil(t, msg.Event().GetData()["mentions"])
|
||||
|
||||
hook.Process(msg, webConn, map[string]any{
|
||||
"mentions": model.StringArray{otherUserID},
|
||||
})
|
||||
|
||||
assert.Nil(t, msg.Event().GetData()["mentions"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddFollowersHook_Process(t *testing.T) {
|
||||
hook := &addFollowersBroadcastHook{}
|
||||
|
||||
userID := model.NewId()
|
||||
otherUserID := model.NewId()
|
||||
|
||||
webConn := &platform.WebConn{
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
t.Run("should add a followers entry for the current user", func(t *testing.T) {
|
||||
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||
|
||||
require.Nil(t, msg.Event().GetData()["followers"])
|
||||
|
||||
hook.Process(msg, webConn, map[string]any{
|
||||
"followers": model.StringArray{userID},
|
||||
})
|
||||
|
||||
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
|
||||
})
|
||||
|
||||
t.Run("should not add a followers entry for another user", func(t *testing.T) {
|
||||
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||
|
||||
require.Nil(t, msg.Event().GetData()["followers"])
|
||||
|
||||
hook.Process(msg, webConn, map[string]any{
|
||||
"followers": model.StringArray{otherUserID},
|
||||
})
|
||||
|
||||
assert.Nil(t, msg.Event().GetData()["followers"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddMentionsAndAddFollowersHooks(t *testing.T) {
|
||||
addMentionsHook := &addMentionsBroadcastHook{}
|
||||
addFollowersHook := &addFollowersBroadcastHook{}
|
||||
|
||||
userID := model.NewId()
|
||||
|
||||
webConn := &platform.WebConn{
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
|
||||
|
||||
originalData := msg.Event().GetData()
|
||||
|
||||
require.Nil(t, originalData["mentions"])
|
||||
require.Nil(t, originalData["followers"])
|
||||
|
||||
addMentionsHook.Process(msg, webConn, map[string]any{
|
||||
"mentions": model.StringArray{userID},
|
||||
})
|
||||
addFollowersHook.Process(msg, webConn, map[string]any{
|
||||
"followers": model.StringArray{userID},
|
||||
})
|
||||
|
||||
t.Run("should be able to add both mentions and followers to a single event", func(t *testing.T) {
|
||||
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
|
||||
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
|
||||
})
|
||||
}
|
@ -104,6 +104,17 @@ type WebsocketBroadcast struct {
|
||||
// ReliableClusterSend indicates whether or not the message should
|
||||
// be sent through the cluster using the reliable, TCP backed channel.
|
||||
ReliableClusterSend bool `json:"-"`
|
||||
|
||||
// BroadcastHooks is a slice of hooks IDs used to process events before sending them on individual connections. The
|
||||
// IDs should be understood by the WebSocket code.
|
||||
//
|
||||
// This field should never be sent to the client.
|
||||
BroadcastHooks []string `json:"broadcast_hooks,omitempty"`
|
||||
// BroadcastHookArgs is a slice of named arguments for each hook invocation. The index of each entry corresponds to
|
||||
// the index of a hook ID in BroadcastHooks
|
||||
//
|
||||
// This field should never be sent to the client.
|
||||
BroadcastHookArgs []map[string]any `json:"broadcast_hook_args,omitempty"`
|
||||
}
|
||||
|
||||
func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
|
||||
@ -124,10 +135,17 @@ func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
|
||||
c.OmitConnectionId = wb.OmitConnectionId
|
||||
c.ContainsSanitizedData = wb.ContainsSanitizedData
|
||||
c.ContainsSensitiveData = wb.ContainsSensitiveData
|
||||
c.BroadcastHooks = wb.BroadcastHooks
|
||||
c.BroadcastHookArgs = wb.BroadcastHookArgs
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (wb *WebsocketBroadcast) AddHook(hookID string, hookArgs map[string]any) {
|
||||
wb.BroadcastHooks = append(wb.BroadcastHooks, hookID)
|
||||
wb.BroadcastHookArgs = append(wb.BroadcastHookArgs, hookArgs)
|
||||
}
|
||||
|
||||
type precomputedWebSocketEventJSON struct {
|
||||
Event json.RawMessage
|
||||
Data json.RawMessage
|
||||
@ -190,6 +208,32 @@ func (ev *WebSocketEvent) PrecomputeJSON() *WebSocketEvent {
|
||||
return evCopy
|
||||
}
|
||||
|
||||
func (ev *WebSocketEvent) RemovePrecomputedJSON() *WebSocketEvent {
|
||||
evCopy := ev.DeepCopy()
|
||||
evCopy.precomputedJSON = nil
|
||||
return evCopy
|
||||
}
|
||||
|
||||
// WithoutBroadcastHooks gets the broadcast hook information from a WebSocketEvent and returns the event without that.
|
||||
// If the event has broadcast hooks, a copy of the event is returned. Otherwise, the original event is returned. This
|
||||
// is intended to be called before the event is sent to the client.
|
||||
func (ev *WebSocketEvent) WithoutBroadcastHooks() (*WebSocketEvent, []string, []map[string]any) {
|
||||
hooks := ev.broadcast.BroadcastHooks
|
||||
hookArgs := ev.broadcast.BroadcastHookArgs
|
||||
|
||||
if len(hooks) == 0 && len(hookArgs) == 0 {
|
||||
return ev, hooks, hookArgs
|
||||
}
|
||||
|
||||
evCopy := ev.Copy()
|
||||
evCopy.broadcast = ev.broadcast.copy()
|
||||
|
||||
evCopy.broadcast.BroadcastHooks = nil
|
||||
evCopy.broadcast.BroadcastHookArgs = nil
|
||||
|
||||
return evCopy, hooks, hookArgs
|
||||
}
|
||||
|
||||
func (ev *WebSocketEvent) Add(key string, value any) {
|
||||
ev.data[key] = value
|
||||
}
|
||||
@ -219,17 +263,9 @@ func (ev *WebSocketEvent) Copy() *WebSocketEvent {
|
||||
}
|
||||
|
||||
func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
|
||||
var dataCopy map[string]any
|
||||
if ev.data != nil {
|
||||
dataCopy = make(map[string]any, len(ev.data))
|
||||
for k, v := range ev.data {
|
||||
dataCopy[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
evCopy := &WebSocketEvent{
|
||||
event: ev.event,
|
||||
data: dataCopy,
|
||||
data: copyMap(ev.data),
|
||||
broadcast: ev.broadcast.copy(),
|
||||
sequence: ev.sequence,
|
||||
precomputedJSON: ev.precomputedJSON.copy(),
|
||||
@ -237,6 +273,14 @@ func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
|
||||
return evCopy
|
||||
}
|
||||
|
||||
func copyMap[K comparable, V any](m map[K]V) map[K]V {
|
||||
dataCopy := make(map[K]V, len(m))
|
||||
for k, v := range m {
|
||||
dataCopy[k] = v
|
||||
}
|
||||
return dataCopy
|
||||
}
|
||||
|
||||
func (ev *WebSocketEvent) GetData() map[string]any {
|
||||
return ev.data
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user