MM-54238 Add WebSocket broadcast hook and don't broadcast when other users were mentioned (#24641)

* MM-54238 Initial implementation

* MM-54238 Move websocket hook into app package

* MM-54238 Add tests for mentions in posted websocket messages

* Fix styling

* Fix other styling

* Idiomatic ID naming for new code

* Fix more styles

* Separate hooks to add mentions and followers

* Improved error handling for invalid types in hooks

* Rename HasChanges to ShouldProcess

* Pass broadcast hooks through hubStart

* Add test helper for asserting json unmarshaling

* Fix missing arguments in tests

* Ensure broadcast hooks are sent across the cluster and not to users

* Ensure tests actually cover following a post

* Fix code broken by merge

* Go vet again...

* Deep copy event before processing it with hooks

* Replace RemoveBroadcastHooks with WithoutBroadcastHooks

* Address feedback

* Add helper to fix type information for hook args

* Wrap WebSocketEvent and simplify BroadcastHook

* Address feedback

* Address feedback
This commit is contained in:
Harrison Healey 2023-11-08 16:17:07 -05:00 committed by GitHub
parent 5e62ba8ccc
commit ef66f7beab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 942 additions and 23 deletions

View File

@ -517,12 +517,12 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea
}
}
if len(mentionedUsersList) != 0 {
message.Add("mentions", model.ArrayToJSON(mentionedUsersList))
if len(mentionedUsersList) > 0 {
useAddMentionsHook(message, mentionedUsersList)
}
if len(notificationsForCRT.Desktop) != 0 {
message.Add("followers", model.ArrayToJSON(notificationsForCRT.Desktop))
if len(notificationsForCRT.Desktop) > 0 {
useAddFollowersHook(message, notificationsForCRT.Desktop)
}
published, err := a.publishWebsocketEventForPermalinkPost(c, post, message)

View File

@ -4,10 +4,14 @@
package app
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -15,6 +19,7 @@ import (
"github.com/mattermost/mattermost/server/public/shared/i18n"
pUtils "github.com/mattermost/mattermost/server/public/utils"
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
"github.com/mattermost/mattermost/server/v8/channels/store"
)
@ -212,6 +217,314 @@ func TestSendNotifications(t *testing.T) {
})
}
func TestSendNotifications_MentionsFollowers(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
th.AddUserToChannel(th.BasicUser2, th.BasicChannel)
sender := th.CreateUser()
th.LinkUserToTeam(sender, th.BasicTeam)
member := th.AddUserToChannel(sender, th.BasicChannel)
t.Run("should inform each user if they were mentioned by a post", func(t *testing.T) {
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
defer closeWS1()
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
defer closeWS2()
// First post mentioning the whole channel
post := &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
Message: "@channel",
}
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
require.NoError(t, err)
received1 := <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
received2 := <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
// Second post mentioning both users individually
post = &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
Message: fmt.Sprintf("@%s @%s", th.BasicUser.Username, th.BasicUser2.Username),
}
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
require.NoError(t, err)
received1 = <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
received2 = <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
// Third post mentioning a single user
post = &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
Message: "@" + th.BasicUser.Username,
}
_, err = th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
require.NoError(t, err)
received1 = <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
received2 = <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assert.Nil(t, received2.GetData()["mentions"])
})
t.Run("should inform each user in a group if they were mentioned by a post", func(t *testing.T) {
// Make the sender a channel_admin because that's needed for group mentions
originalRoles := member.Roles
member.Roles = "channel_user channel_admin"
_, appErr := th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, member.Roles)
require.Nil(t, appErr)
defer func() {
th.App.UpdateChannelMemberRoles(th.Context, member.ChannelId, member.UserId, originalRoles)
}()
th.App.Srv().SetLicense(getLicWithSkuShortName(model.LicenseShortSkuEnterprise))
// Make a group and add users
group := th.CreateGroup()
group.AllowReference = true
group, updateErr := th.App.UpdateGroup(group)
require.Nil(t, updateErr)
_, upsertErr := th.App.UpsertGroupMember(group.Id, th.BasicUser.Id)
require.Nil(t, upsertErr)
_, upsertErr = th.App.UpsertGroupMember(group.Id, th.BasicUser2.Id)
require.Nil(t, upsertErr)
// Set up the websockets
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
defer closeWS1()
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
defer closeWS2()
// Confirm permissions for group mentions are correct
post := &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
Message: "@" + *group.Name,
}
require.True(t, th.App.allowGroupMentions(th.Context, post))
// Test sending notifications
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
require.NoError(t, err)
received1 := <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["mentions"])
received2 := <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser2.Id}, received2.GetData()["mentions"])
})
t.Run("should inform each user if they are following a thread that was posted in", func(t *testing.T) {
t.Log("BasicUser ", th.BasicUser.Id)
t.Log("sender ", sender.Id)
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
defer closeWS1()
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
defer closeWS2()
// Reply to a post made by BasicUser
post := &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
RootId: th.BasicPost.Id,
Message: "This is a test",
}
// Use CreatePost instead of SendNotifications here since we need that to set up some threads state
_, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, false, false)
require.Nil(t, appErr)
received1 := <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assertUnmarshalsTo(t, []string{th.BasicUser.Id}, received1.GetData()["followers"])
received2 := <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assert.Nil(t, received2.GetData()["followers"])
})
t.Run("should not include broadcast hook information in messages sent to users", func(t *testing.T) {
messages1, closeWS1 := connectFakeWebSocket(t, th, th.BasicUser.Id, "")
defer closeWS1()
messages2, closeWS2 := connectFakeWebSocket(t, th, th.BasicUser2.Id, "")
defer closeWS2()
// For a post mentioning only one user, nobody in the channel should receive information about the broadcast hooks
post := &model.Post{
UserId: sender.Id,
ChannelId: th.BasicChannel.Id,
Message: fmt.Sprintf("@%s", th.BasicUser.Username),
}
_, err := th.App.SendNotifications(th.Context, post, th.BasicTeam, th.BasicChannel, sender, nil, false)
require.NoError(t, err)
received1 := <-messages1
require.Equal(t, model.WebsocketEventPosted, received1.EventType())
assert.Nil(t, received1.GetBroadcast().BroadcastHooks)
assert.Nil(t, received1.GetBroadcast().BroadcastHookArgs)
received2 := <-messages2
require.Equal(t, model.WebsocketEventPosted, received2.EventType())
assert.Nil(t, received2.GetBroadcast().BroadcastHooks)
assert.Nil(t, received2.GetBroadcast().BroadcastHookArgs)
})
}
func assertUnmarshalsTo(t *testing.T, expected any, actual any) {
t.Helper()
val, err := json.Marshal(expected)
require.NoError(t, err)
assert.JSONEq(t, string(val), actual.(string))
}
func connectFakeWebSocket(t *testing.T, th *TestHelper, userID string, connectionID string) (chan *model.WebSocketEvent, func()) {
var session *model.Session
var server *httptest.Server
var webConn *platform.WebConn
closeWS := func() {
if webConn != nil {
webConn.Close()
}
if server != nil {
server.Close()
}
if session != nil {
appErr := th.App.RevokeSession(th.Context, session)
require.Nil(t, appErr)
}
}
// Create a session for the user's connection
var appErr *model.AppError
session, appErr = th.App.CreateSession(th.Context, &model.Session{
UserId: userID,
})
require.Nil(t, appErr)
// Create a channel and an HTTP server to handle incoming WS events
messages := make(chan *model.WebSocketEvent)
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgrader := &websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Log("Received error when upgrading WebSocket connection", err)
return
}
defer c.Close()
for {
_, reader, err := c.NextReader()
if err != nil {
t.Log("Received error when reading from WebSocket connection", err)
break
}
msg, err := model.WebSocketEventFromJSON(reader)
if err != nil {
t.Log("Received error when decoding from WebSocket connection", err)
break
}
messages <- msg
}
}))
// Connect the WebSocket
d := websocket.Dialer{}
ws, _, err := d.Dial("ws://"+server.Listener.Addr().String(), nil)
require.NoError(t, err)
// Register the WebSocket with the server as a WebConn
if connectionID == "" {
connectionID = model.NewId()
}
webConn = th.App.Srv().Platform().NewWebConn(&platform.WebConnConfig{
WebSocket: ws,
Session: *session,
TFunc: i18n.IdentityTfunc(),
Locale: "en",
ConnectionID: connectionID,
}, th.App, th.App.Channels())
th.App.Srv().Platform().HubRegister(webConn)
// Start reading from it
go webConn.Pump()
// Read the events which always occur at the start of a WebSocket connection
received := <-messages
assert.Equal(t, model.WebsocketEventHello, received.EventType())
received = <-messages
assert.Equal(t, model.WebsocketEventStatusChange, received.EventType())
return messages, closeWS
}
func TestConnectFakeWebSocket(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
teamID := th.BasicTeam.Id
userID := th.BasicUser.Id
messages, closeWS := connectFakeWebSocket(t, th, userID, "")
defer closeWS()
msg := model.NewWebSocketEvent(model.WebsocketEventPosted, teamID, "", "", nil, "")
th.App.Publish(msg)
msg = model.NewWebSocketEvent("test_event_with_data", "", "", userID, nil, "")
msg.Add("key1", "value1")
msg.Add("key2", 2)
msg.Add("key3", []string{"three", "trois"})
th.App.Publish(msg)
received := <-messages
require.Equal(t, model.WebsocketEventPosted, received.EventType())
assert.Equal(t, teamID, received.GetBroadcast().TeamId)
received = <-messages
require.Equal(t, "test_event_with_data", received.EventType())
assert.Equal(t, userID, received.GetBroadcast().UserId)
// These type changes are annoying but unavoidable because event data is untyped
assert.Equal(t, map[string]any{
"key1": "value1",
"key2": float64(2),
"key3": []any{"three", "trois"},
}, received.GetData())
}
func TestSendNotificationsWithManyUsers(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()

View File

@ -187,7 +187,7 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
th.Service.SetLicense(nil)
}
err = th.Service.Start()
err = th.Service.Start(nil)
if err != nil {
panic(err)
}

View File

@ -344,8 +344,8 @@ func New(sc ServiceConfig, options ...Option) (*PlatformService, error) {
return ps, nil
}
func (ps *PlatformService) Start() error {
ps.hubStart()
func (ps *PlatformService) Start(broadcastHooks map[string]BroadcastHook) error {
ps.hubStart(broadcastHooks)
ps.configListenerId = ps.AddConfigListener(func(_, _ *model.Config) {
ps.regenerateClientConfig()

View File

@ -0,0 +1,87 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
type BroadcastHook interface {
// Process takes a WebSocket event and modifies it in some way. It is passed a HookedWebSocketEvent which allows
// safe modification of the event.
Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error
}
func (h *Hub) runBroadcastHooks(msg *model.WebSocketEvent, webConn *WebConn, hookIDs []string, hookArgs []map[string]any) *model.WebSocketEvent {
if len(hookIDs) == 0 {
return msg
}
hookedEvent := MakeHookedWebSocketEvent(msg)
for i, hookID := range hookIDs {
hook := h.broadcastHooks[hookID]
args := hookArgs[i]
if hook == nil {
mlog.Warn("runBroadcastHooks: Unable to find broadcast hook", mlog.String("hook_id", hookID))
continue
}
hook.Process(hookedEvent, webConn, args)
}
return hookedEvent.Event()
}
// HookedWebSocketEvent is a wrapper for model.WebSocketEvent that is intended to provide a similar interface, except
// it ensures the original WebSocket event is not modified.
type HookedWebSocketEvent struct {
original *model.WebSocketEvent
copy *model.WebSocketEvent
}
func MakeHookedWebSocketEvent(event *model.WebSocketEvent) *HookedWebSocketEvent {
return &HookedWebSocketEvent{
original: event,
}
}
func (he *HookedWebSocketEvent) Add(key string, value any) {
he.copyIfNecessary()
he.copy.Add(key, value)
}
func (he *HookedWebSocketEvent) EventType() string {
if he.copy == nil {
return he.original.EventType()
}
return he.copy.EventType()
}
// Get returns a value from the WebSocket event data. You should never mutate a value returned by this method.
func (he *HookedWebSocketEvent) Get(key string) any {
if he.copy == nil {
return he.original.GetData()[key]
}
return he.copy.GetData()[key]
}
// copyIfNecessary should be called by any mutative method to ensure that the copy is instantiated.
func (he *HookedWebSocketEvent) copyIfNecessary() {
if he.copy == nil {
he.copy = he.original.RemovePrecomputedJSON()
}
}
func (he *HookedWebSocketEvent) Event() *model.WebSocketEvent {
if he.copy == nil {
return he.original
}
return he.copy
}

View File

@ -0,0 +1,206 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const broadcastTest = "test_broadcast_hook"
type testBroadcastHook struct{}
func (h *testBroadcastHook) Process(msg *HookedWebSocketEvent, webConn *WebConn, args map[string]any) error {
if args["makes_changes"].(bool) {
changesMade, _ := msg.Get("changes_made").(int)
msg.Add("changes_made", changesMade+1)
}
return nil
}
func TestRunBroadcastHooks(t *testing.T) {
hub := &Hub{
broadcastHooks: map[string]BroadcastHook{
broadcastTest: &testBroadcastHook{},
},
}
webConn := &WebConn{}
t.Run("should not allocate a new object when no hooks are passed", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
result := hub.runBroadcastHooks(event, webConn, nil, nil)
assert.Same(t, event, result)
})
t.Run("should not allocate a new object when a hook is not making changes", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
hookIDs := []string{
broadcastTest,
}
hookArgs := []map[string]any{
{
"makes_changes": false,
},
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
assert.Same(t, event, result)
})
t.Run("should allocate a new object and remove when a hook makes changes", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
hookIDs := []string{
broadcastTest,
}
hookArgs := []map[string]any{
{
"makes_changes": true,
},
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
assert.NotSame(t, event, result)
assert.NotSame(t, event.GetData(), result.GetData())
assert.Equal(t, map[string]any{}, event.GetData())
assert.Equal(t, result.GetData(), map[string]any{
"changes_made": 1,
})
})
t.Run("should not allocate a new object when multiple hooks are not making changes", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
hookIDs := []string{
broadcastTest,
broadcastTest,
broadcastTest,
}
hookArgs := []map[string]any{
{
"makes_changes": false,
},
{
"makes_changes": false,
},
{
"makes_changes": false,
},
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
assert.Same(t, event, result)
})
t.Run("should be able to make changes from only one of make hooks", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
var hookIDs []string
var hookArgs []map[string]any
for i := 0; i < 10; i++ {
hookIDs = append(hookIDs, broadcastTest)
hookArgs = append(hookArgs, map[string]any{
"makes_changes": i == 6,
})
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
assert.NotSame(t, event, result)
assert.NotSame(t, event.GetData(), result.GetData())
assert.Equal(t, event.GetData(), map[string]any{})
assert.Equal(t, result.GetData(), map[string]any{
"changes_made": 1,
})
})
t.Run("should be able to make changes from multiple hooks", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
var hookIDs []string
var hookArgs []map[string]any
for i := 0; i < 10; i++ {
hookIDs = append(hookIDs, broadcastTest)
hookArgs = append(hookArgs, map[string]any{
"makes_changes": true,
})
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
assert.NotSame(t, event, result)
assert.NotSame(t, event.GetData(), result.GetData())
assert.Equal(t, event.GetData(), map[string]any{})
assert.Equal(t, result.GetData(), map[string]any{
"changes_made": 10,
})
})
t.Run("should not remove precomputed JSON when a hook doesn't make changes", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
event = event.PrecomputeJSON()
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
originalJSON, _ := event.ToJSON()
event.Add("data", 1234)
eventJSON, _ := event.ToJSON()
require.Equal(t, string(originalJSON), string(eventJSON))
hookIDs := []string{
broadcastTest,
}
hookArgs := []map[string]any{
{
"makes_changes": false,
},
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
eventJSON, _ = event.ToJSON()
assert.Equal(t, string(originalJSON), string(eventJSON))
resultJSON, _ := result.ToJSON()
assert.Equal(t, originalJSON, resultJSON)
})
t.Run("should remove precomputed JSON when a hook makes changes", func(t *testing.T) {
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
event = event.PrecomputeJSON()
// Ensure that the event has precomputed JSON because changes aren't included when ToJSON is called again
originalJSON, _ := event.ToJSON()
event.Add("data", 1234)
eventJSON, _ := event.ToJSON()
require.Equal(t, originalJSON, eventJSON)
hookIDs := []string{
broadcastTest,
}
hookArgs := []map[string]any{
{
"makes_changes": true,
},
}
result := hub.runBroadcastHooks(event, webConn, hookIDs, hookArgs)
eventJSON, _ = event.ToJSON()
assert.Equal(t, string(originalJSON), string(eventJSON))
resultJSON, _ := result.ToJSON()
assert.NotEqual(t, originalJSON, resultJSON)
})
}

View File

@ -70,6 +70,7 @@ type Hub struct {
explicitStop bool
checkRegistered chan *webConnSessionMessage
checkConn chan *webConnCheckMessage
broadcastHooks map[string]BroadcastHook
}
// newWebHub creates a new Hub.
@ -90,7 +91,7 @@ func newWebHub(ps *PlatformService) *Hub {
}
// hubStart starts all the hubs.
func (ps *PlatformService) hubStart() {
func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) {
// Total number of hubs is twice the number of CPUs.
numberOfHubs := runtime.NumCPU() * 2
ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
@ -100,6 +101,7 @@ func (ps *PlatformService) hubStart() {
for i := 0; i < numberOfHubs; i++ {
hubs[i] = newWebHub(ps)
hubs[i].connectionIndex = i
hubs[i].broadcastHooks = broadcastHooks
hubs[i].Start()
}
// Assigning to the hubs slice without any mutex is fine because it is only assigned once
@ -492,14 +494,19 @@ func (h *Hub) Start() {
if metrics := h.platform.metricsIFace; metrics != nil {
metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
}
// Remove the broadcast hook information before precomputing the JSON so that those aren't included in it
msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks()
msg = msg.PrecomputeJSON()
broadcast := func(webConn *WebConn) {
if !connIndex.Has(webConn) {
return
}
if webConn.ShouldSendEvent(msg) {
select {
case webConn.send <- msg:
case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
default:
// Don't log the warning if it's an inactive connection.
if webConn.active.Load() {

View File

@ -4,6 +4,7 @@
package platform
import (
"bytes"
"encoding/json"
"net"
"net/http"
@ -69,7 +70,7 @@ func TestHubStopWithMultipleConnections(t *testing.T) {
})
require.NoError(t, err)
th.Service.Start()
th.Service.Start(nil)
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
@ -93,7 +94,7 @@ func TestHubStopRaceCondition(t *testing.T) {
})
require.NoError(t, err)
th.Service.Start()
th.Service.Start(nil)
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
defer wc1.Close()
@ -476,7 +477,7 @@ func TestHubIsRegistered(t *testing.T) {
s := httptest.NewServer(dummyWebsocketHandler(t))
defer s.Close()
th.Service.Start()
th.Service.Start(nil)
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc2 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc3 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
@ -583,7 +584,7 @@ func BenchmarkGetHubForUserId(b *testing.B) {
th := Setup(b).InitBasic()
defer th.TearDown()
th.Service.Start()
th.Service.Start(nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -618,3 +619,54 @@ func TestClusterBroadcast(t *testing.T) {
require.NoError(t, err)
require.Equal(t, clusterEvent.Broadcast, broadcast)
}
func TestClusterBroadcastHooks(t *testing.T) {
t.Run("should send broadcast hook information across cluster", func(t *testing.T) {
testCluster := &testlib.FakeClusterInterface{}
th := SetupWithCluster(t, testCluster)
defer th.TearDown()
hookID := broadcastTest
hookArgs := map[string]any{
"makes_changes": true,
}
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
event.GetBroadcast().AddHook(hookID, hookArgs)
th.Service.Publish(event)
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
require.NoError(t, err)
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
assert.Equal(t, []map[string]any{hookArgs}, received.GetBroadcast().BroadcastHookArgs)
})
t.Run("should not preserve type information for args", func(t *testing.T) {
// This behaviour isn't ideal, but this test confirms that it hasn't changed
testCluster := &testlib.FakeClusterInterface{}
th := SetupWithCluster(t, testCluster)
defer th.TearDown()
hookID := "test_broadcast_hook_with_args"
hookArgs := map[string]any{
"user": &model.User{Id: "user1"},
"array": []string{"a", "b", "c"},
}
event := model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, "")
event.GetBroadcast().AddHook(hookID, hookArgs)
th.Service.Publish(event)
received, err := model.WebSocketEventFromJSON(bytes.NewReader(testCluster.GetMessages()[0].Data))
require.NoError(t, err)
assert.Equal(t, []string{hookID}, received.GetBroadcast().BroadcastHooks)
assert.IsType(t, map[string]any{}, received.GetBroadcast().BroadcastHookArgs[0]["user"])
assert.IsType(t, []any{}, received.GetBroadcast().BroadcastHookArgs[0]["array"])
})
}

View File

@ -270,7 +270,7 @@ func NewServer(options ...Option) (*Server, error) {
// It is important to initialize the hub only after the global logger is set
// to avoid race conditions while logging from inside the hub.
// Step 4: Start platform
s.platform.Start()
s.platform.Start(s.makeBroadcastHooks())
// NOTE: There should be no call to App.Srv().Channels() before step 5 is done
// otherwise it will throw a panic.

View File

@ -0,0 +1,96 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package app
import (
"encoding/json"
"fmt"
"github.com/mattermost/mattermost/server/public/model"
pUtils "github.com/mattermost/mattermost/server/public/utils"
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
"github.com/pkg/errors"
)
const (
broadcastAddMentions = "add_mentions"
broadcastAddFollowers = "add_followers"
)
func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook {
return map[string]platform.BroadcastHook{
broadcastAddMentions: &addMentionsBroadcastHook{},
broadcastAddFollowers: &addFollowersBroadcastHook{},
}
}
type addMentionsBroadcastHook struct{}
func (h *addMentionsBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
mentions, err := getTypedArg[model.StringArray](args, "mentions")
if err != nil {
return errors.Wrap(err, "Invalid mentions value passed to addMentionsBroadcastHook")
}
if len(mentions) > 0 && pUtils.Contains[string](mentions, webConn.UserId) {
// Note that the client expects this field to be stringified
msg.Add("mentions", model.ArrayToJSON([]string{webConn.UserId}))
}
return nil
}
func useAddMentionsHook(message *model.WebSocketEvent, mentionedUsers model.StringArray) {
message.GetBroadcast().AddHook(broadcastAddMentions, map[string]any{
"mentions": mentionedUsers,
})
}
type addFollowersBroadcastHook struct{}
func (h *addFollowersBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error {
followers, err := getTypedArg[model.StringArray](args, "followers")
if err != nil {
return errors.Wrap(err, "Invalid followers value passed to addFollowersBroadcastHook")
}
if len(followers) > 0 && pUtils.Contains[string](followers, webConn.UserId) {
// Note that the client expects this field to be stringified
msg.Add("followers", model.ArrayToJSON([]string{webConn.UserId}))
}
return nil
}
func useAddFollowersHook(message *model.WebSocketEvent, followers model.StringArray) {
message.GetBroadcast().AddHook(broadcastAddFollowers, map[string]any{
"followers": followers,
})
}
// getTypedArg returns a correctly typed hook argument with the given key, reinterpreting the type using JSON encoding
// if necessary. This is needed because broadcast hook args are JSON encoded in a multi-server environment, and any
// type information is lost because those types aren't known at decode time.
func getTypedArg[T any](args map[string]any, key string) (T, error) {
var value T
untyped, ok := args[key]
if !ok {
return value, fmt.Errorf("No argument found with key: %s", key)
}
// If the value is already correct, just return it
if typed, ok := untyped.(T); ok {
return typed, nil
}
// Marshal and unmarshal the data with the correct typing information
buf, err := json.Marshal(untyped)
if err != nil {
return value, err
}
err = json.Unmarshal(buf, &value)
return value, err
}

View File

@ -0,0 +1,114 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package app
import (
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddMentionsHook_Process(t *testing.T) {
hook := &addMentionsBroadcastHook{}
userID := model.NewId()
otherUserID := model.NewId()
webConn := &platform.WebConn{
UserId: userID,
}
t.Run("should add a mentions entry for the current user", func(t *testing.T) {
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
require.Nil(t, msg.Event().GetData()["mentions"])
hook.Process(msg, webConn, map[string]any{
"mentions": model.StringArray{userID},
})
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
assert.Nil(t, msg.Event().GetData()["followers"])
})
t.Run("should not add a mentions entry for another user", func(t *testing.T) {
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
require.Nil(t, msg.Event().GetData()["mentions"])
hook.Process(msg, webConn, map[string]any{
"mentions": model.StringArray{otherUserID},
})
assert.Nil(t, msg.Event().GetData()["mentions"])
})
}
func TestAddFollowersHook_Process(t *testing.T) {
hook := &addFollowersBroadcastHook{}
userID := model.NewId()
otherUserID := model.NewId()
webConn := &platform.WebConn{
UserId: userID,
}
t.Run("should add a followers entry for the current user", func(t *testing.T) {
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
require.Nil(t, msg.Event().GetData()["followers"])
hook.Process(msg, webConn, map[string]any{
"followers": model.StringArray{userID},
})
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
})
t.Run("should not add a followers entry for another user", func(t *testing.T) {
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
require.Nil(t, msg.Event().GetData()["followers"])
hook.Process(msg, webConn, map[string]any{
"followers": model.StringArray{otherUserID},
})
assert.Nil(t, msg.Event().GetData()["followers"])
})
}
func TestAddMentionsAndAddFollowersHooks(t *testing.T) {
addMentionsHook := &addMentionsBroadcastHook{}
addFollowersHook := &addFollowersBroadcastHook{}
userID := model.NewId()
webConn := &platform.WebConn{
UserId: userID,
}
msg := platform.MakeHookedWebSocketEvent(model.NewWebSocketEvent(model.WebsocketEventPosted, "", "", "", nil, ""))
originalData := msg.Event().GetData()
require.Nil(t, originalData["mentions"])
require.Nil(t, originalData["followers"])
addMentionsHook.Process(msg, webConn, map[string]any{
"mentions": model.StringArray{userID},
})
addFollowersHook.Process(msg, webConn, map[string]any{
"followers": model.StringArray{userID},
})
t.Run("should be able to add both mentions and followers to a single event", func(t *testing.T) {
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["followers"])
assert.Equal(t, `["`+userID+`"]`, msg.Event().GetData()["mentions"])
})
}

View File

@ -104,6 +104,17 @@ type WebsocketBroadcast struct {
// ReliableClusterSend indicates whether or not the message should
// be sent through the cluster using the reliable, TCP backed channel.
ReliableClusterSend bool `json:"-"`
// BroadcastHooks is a slice of hooks IDs used to process events before sending them on individual connections. The
// IDs should be understood by the WebSocket code.
//
// This field should never be sent to the client.
BroadcastHooks []string `json:"broadcast_hooks,omitempty"`
// BroadcastHookArgs is a slice of named arguments for each hook invocation. The index of each entry corresponds to
// the index of a hook ID in BroadcastHooks
//
// This field should never be sent to the client.
BroadcastHookArgs []map[string]any `json:"broadcast_hook_args,omitempty"`
}
func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
@ -124,10 +135,17 @@ func (wb *WebsocketBroadcast) copy() *WebsocketBroadcast {
c.OmitConnectionId = wb.OmitConnectionId
c.ContainsSanitizedData = wb.ContainsSanitizedData
c.ContainsSensitiveData = wb.ContainsSensitiveData
c.BroadcastHooks = wb.BroadcastHooks
c.BroadcastHookArgs = wb.BroadcastHookArgs
return &c
}
func (wb *WebsocketBroadcast) AddHook(hookID string, hookArgs map[string]any) {
wb.BroadcastHooks = append(wb.BroadcastHooks, hookID)
wb.BroadcastHookArgs = append(wb.BroadcastHookArgs, hookArgs)
}
type precomputedWebSocketEventJSON struct {
Event json.RawMessage
Data json.RawMessage
@ -190,6 +208,32 @@ func (ev *WebSocketEvent) PrecomputeJSON() *WebSocketEvent {
return evCopy
}
func (ev *WebSocketEvent) RemovePrecomputedJSON() *WebSocketEvent {
evCopy := ev.DeepCopy()
evCopy.precomputedJSON = nil
return evCopy
}
// WithoutBroadcastHooks gets the broadcast hook information from a WebSocketEvent and returns the event without that.
// If the event has broadcast hooks, a copy of the event is returned. Otherwise, the original event is returned. This
// is intended to be called before the event is sent to the client.
func (ev *WebSocketEvent) WithoutBroadcastHooks() (*WebSocketEvent, []string, []map[string]any) {
hooks := ev.broadcast.BroadcastHooks
hookArgs := ev.broadcast.BroadcastHookArgs
if len(hooks) == 0 && len(hookArgs) == 0 {
return ev, hooks, hookArgs
}
evCopy := ev.Copy()
evCopy.broadcast = ev.broadcast.copy()
evCopy.broadcast.BroadcastHooks = nil
evCopy.broadcast.BroadcastHookArgs = nil
return evCopy, hooks, hookArgs
}
func (ev *WebSocketEvent) Add(key string, value any) {
ev.data[key] = value
}
@ -219,17 +263,9 @@ func (ev *WebSocketEvent) Copy() *WebSocketEvent {
}
func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
var dataCopy map[string]any
if ev.data != nil {
dataCopy = make(map[string]any, len(ev.data))
for k, v := range ev.data {
dataCopy[k] = v
}
}
evCopy := &WebSocketEvent{
event: ev.event,
data: dataCopy,
data: copyMap(ev.data),
broadcast: ev.broadcast.copy(),
sequence: ev.sequence,
precomputedJSON: ev.precomputedJSON.copy(),
@ -237,6 +273,14 @@ func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
return evCopy
}
func copyMap[K comparable, V any](m map[K]V) map[K]V {
dataCopy := make(map[K]V, len(m))
for k, v := range m {
dataCopy[k] = v
}
return dataCopy
}
func (ev *WebSocketEvent) GetData() map[string]any {
return ev.data
}