Revert "MM-42810: Introduce a channel hook for a websocket event (#23812)" (#24107)

Automatic Merge
This commit is contained in:
Agniva De Sarker 2023-07-24 21:46:57 +05:30 committed by GitHub
parent e70abd6e0f
commit 29bd0c9357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 155 additions and 211 deletions

View File

@ -71,17 +71,6 @@ func TestHasPermissionToTeam(t *testing.T) {
assert.True(t, th.App.HasPermissionToTeam(th.SystemAdminUser.Id, th.BasicTeam.Id, model.PermissionListTeamChannels))
}
func TestSessionHasPermissionToReadChannel(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
assert.True(t, th.App.HasPermissionToReadChannel(th.Context, th.BasicUser.Id, th.BasicChannel))
pc1 := th.CreatePrivateChannel(th.Context, th.BasicTeam)
assert.False(t, th.App.HasPermissionToReadChannel(th.Context, th.BasicUser2.Id, pc1))
th.AddUserToChannel(th.BasicUser2, pc1)
assert.True(t, th.App.HasPermissionToReadChannel(th.Context, th.BasicUser2.Id, pc1))
}
func TestSessionHasPermissionToChannel(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()

View File

@ -484,6 +484,12 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea
message := model.NewWebSocketEvent(model.WebsocketEventPosted, "", post.ChannelId, "", nil, "")
// Note that PreparePostForClient should've already been called by this point
postJSON, jsonErr := post.ToJSON()
if jsonErr != nil {
return nil, errors.Wrapf(jsonErr, "failed to encode post to JSON")
}
message.Add("post", postJSON)
message.Add("channel_type", channel.Type)
message.Add("channel_display_name", notification.GetChannelName(model.ShowUsername, ""))
message.Add("channel_name", channel.Name)
@ -517,10 +523,13 @@ func (a *App) SendNotifications(c request.CTX, post *model.Post, team *model.Tea
message.Add("followers", model.ArrayToJSON(notificationsForCRT.Desktop))
}
err := a.publishWebsocketEventForPost(c, post, message)
published, err := a.publishWebsocketEventForPermalinkPost(c, post, message)
if err != nil {
return nil, err
}
if !published {
a.Publish(message)
}
// If this is a reply in a thread, notify participants
if isCRTAllowed && post.RootId != "" {

View File

@ -761,17 +761,6 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
}
}
// The priority checks in order of specificity are:
// ConnectionId
// OmitConnectionId
//
// UserId
// OmitUserId
//
// ChannelId - is member of channel
// TeamId - is member of team
// Guest - does guest have access
// If the event is destined to a specific connection
if msg.GetBroadcast().ConnectionId != "" {
return wc.GetConnectionID() == msg.GetBroadcast().ConnectionId
@ -800,16 +789,6 @@ func (wc *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
wc.lastAllChannelMembersTime = 0
}
// Execute channel hook
if msg.GetBroadcast().ChannelHook != nil {
hasChange := msg.GetBroadcast().ChannelHook(wc.UserId, msg)
if hasChange {
// If hook returns true, that means message has been modified. We need
// to wipe off the pre-computed JSON
msg.RemovePrecomputedJSON()
}
}
if wc.allChannelMembers == nil {
result, err := wc.Platform.Store.Channel().GetAllChannelMembersForUser(wc.UserId, false, false)
if err != nil {

View File

@ -9,7 +9,6 @@ import (
"net/http"
"net/http/httptest"
"runtime"
"sync/atomic"
"testing"
"time"
@ -127,81 +126,6 @@ func TestHubStopRaceCondition(t *testing.T) {
}
}
func TestBroadcastChannelHook(t *testing.T) {
th := SetupWithStoreMock(t)
sess1 := &model.Session{
Id: "id1",
UserId: "user1",
DeviceId: "",
Token: "sesstoken",
ExpiresAt: model.GetMillis() + 300000,
LastActivityAt: 10000,
}
mockStore := th.Service.Store.(*mocks.Store)
mockUserStore := mocks.UserStore{}
mockUserStore.On("Count", mock.Anything).Return(int64(10), nil)
mockUserStore.On("GetUnreadCount", mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(int64(1), nil)
mockPostStore := mocks.PostStore{}
mockPostStore.On("GetMaxPostSize").Return(65535, nil)
mockSystemStore := mocks.SystemStore{}
mockSystemStore.On("GetByName", "UpgradedFromTE").Return(&model.System{Name: "UpgradedFromTE", Value: "false"}, nil)
mockSystemStore.On("GetByName", "InstallationDate").Return(&model.System{Name: "InstallationDate", Value: "10"}, nil)
mockSystemStore.On("GetByName", "FirstServerRunTimestamp").Return(&model.System{Name: "FirstServerRunTimestamp", Value: "10"}, nil)
mockSessionStore := mocks.SessionStore{}
mockSessionStore.On("UpdateLastActivityAt", "id1", mock.Anything).Return(nil)
mockSessionStore.On("Save", mock.AnythingOfType("*model.Session")).Return(sess1, nil)
mockSessionStore.On("Get", mock.Anything, "id1").Return(sess1, nil)
mockSessionStore.On("Remove", "id1").Return(nil)
mockStatusStore := mocks.StatusStore{}
mockStatusStore.On("Get", "user1").Return(&model.Status{UserId: "user1", Status: model.StatusOnline}, nil)
mockStatusStore.On("UpdateLastActivityAt", "user1", mock.Anything).Return(nil)
mockStatusStore.On("SaveOrUpdate", mock.AnythingOfType("*model.Status")).Return(nil)
mockOAuthStore := mocks.OAuthStore{}
mockChannelStore := mocks.ChannelStore{}
mockStore.On("Session").Return(&mockSessionStore)
mockStore.On("OAuth").Return(&mockOAuthStore)
mockStore.On("Status").Return(&mockStatusStore)
mockStore.On("User").Return(&mockUserStore)
mockStore.On("Post").Return(&mockPostStore)
mockStore.On("System").Return(&mockSystemStore)
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("GetDBSchemaVersion").Return(1, nil)
s := httptest.NewServer(dummyWebsocketHandler(t))
defer s.Close()
session, err := th.Service.CreateSession(&model.Session{
UserId: "testid",
})
require.NoError(t, err)
wc1 := registerDummyWebConn(t, th, s.Listener.Addr(), session)
wc1.SetConnectionID("connID")
hub := th.Service.GetHubForUserId(wc1.UserId)
mockChannelStore.On("GetAllChannelMembersForUser", wc1.UserId, false, false).Return(map[string]string{"channelID": "test"}, nil)
ev := model.NewWebSocketEvent("", "", "channelID", "", nil, "")
broadcast := ev.GetBroadcast()
var test atomic.Bool
broadcast.ChannelHook = func(_ string, ev *model.WebSocketEvent) bool {
test.Store(true)
return true
}
ev.SetBroadcast(broadcast)
hub.Broadcast(ev)
// Wait until the goroutines from NewWebConn are finished.
th.Service.waitForGoroutines()
th.TearDown()
assert.Equal(t, true, test.Load())
}
func TestHubSessionRevokeRace(t *testing.T) {
th := SetupWithStoreMock(t)
defer th.TearDown()

View File

@ -597,11 +597,12 @@ func (a *App) UpdateEphemeralPost(c request.CTX, userID string, post *model.Post
message := model.NewWebSocketEvent(model.WebsocketEventPostEdited, "", post.ChannelId, userID, nil, "")
post = a.PreparePostForClientWithEmbedsAndImages(c, post, true, false, true)
post = model.AddPostActionCookies(post, a.PostActionCookieSecret())
appErr := a.publishWebsocketEventForPost(c, post, message)
if appErr != nil {
mlog.Warn("Failed to send websocket event for ephemeral post", mlog.Err(appErr))
postJSON, jsonErr := post.ToJSON()
if jsonErr != nil {
mlog.Warn("Failed to encode post to JSON", mlog.Err(jsonErr))
}
message.Add("post", postJSON)
a.Publish(message)
return post
}
@ -741,95 +742,85 @@ func (a *App) UpdatePost(c *request.Context, post *model.Post, safeUpdate bool)
}
message := model.NewWebSocketEvent(model.WebsocketEventPostEdited, "", rpost.ChannelId, "", nil, "")
postJSON, jsonErr := rpost.ToJSON()
if jsonErr != nil {
return nil, model.NewAppError("UpdatePost", "app.post.marshal.app_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
}
message.Add("post", postJSON)
err = a.publishWebsocketEventForPost(c, rpost, message)
published, err := a.publishWebsocketEventForPermalinkPost(c, rpost, message)
if err != nil {
return nil, err
}
if !published {
a.Publish(message)
}
a.invalidateCacheForChannelPosts(rpost.ChannelId)
return rpost, nil
}
// publishWebsocketEventForPost publishes the websocket event only for post create/edit.
// The cases of post delete/unread does not need special handling as they don't bother
// with the post content.
//
// This method assumes that if there's a permalink, it's already attached to the post.
// If the user doesn't have access then this method will wipe that off.
func (a *App) publishWebsocketEventForPost(c request.CTX, post *model.Post, message *model.WebSocketEvent) (appErr *model.AppError) {
postJSON, jsonErr := post.ToJSON()
if jsonErr != nil {
return model.NewAppError("publishWebsocketEventForPost", "app.post.marshal.app_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
func (a *App) publishWebsocketEventForPermalinkPost(c request.CTX, post *model.Post, message *model.WebSocketEvent) (published bool, err *model.AppError) {
var previewedPostID string
if val, ok := post.GetProp(model.PostPropsPreviewedPost).(string); ok {
previewedPostID = val
} else {
return false, nil
}
message.Add("post", postJSON)
defer func() {
if appErr == nil {
a.Publish(message)
if !model.IsValidId(previewedPostID) {
mlog.Warn("invalid post prop value", mlog.String("prop_key", model.PostPropsPreviewedPost), mlog.String("prop_value", previewedPostID))
return false, nil
}
previewedPost, err := a.GetSinglePost(previewedPostID, false)
if err != nil {
if err.StatusCode == http.StatusNotFound {
mlog.Warn("permalinked post not found", mlog.String("referenced_post_id", previewedPostID))
return false, nil
}
}()
return false, err
}
channelMembers, err := a.GetChannelMembersPage(c, post.ChannelId, 0, 10000000)
if err != nil {
return false, err
}
permalinkPreviewedChannel, err := a.GetChannel(c, previewedPost.ChannelId)
if err != nil {
if err.StatusCode == http.StatusNotFound {
mlog.Warn("channel containing permalinked post not found", mlog.String("referenced_channel_id", previewedPost.ChannelId))
return false, nil
}
return false, err
}
permalinkPreviewedPost := post.GetPreviewPost()
if permalinkPreviewedPost == nil {
return nil
}
if !model.IsValidId(permalinkPreviewedPost.PostID) {
mlog.Warn("invalid preview post ID", mlog.String("prop_value", permalinkPreviewedPost.PostID))
return nil
}
// To remain secure by default, we wipe out the metadata unconditionally.
post.Metadata.Embeds[0].Data = nil
postWithoutPermalinkPreviewJSON, err := post.ToJSON()
if err != nil {
return model.NewAppError("publishWebsocketEventForPost", "app.post.marshal.app_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
}
var previewedPost *model.Post
previewedPost, appErr = a.GetSinglePost(permalinkPreviewedPost.PostID, false)
if appErr != nil {
if appErr.StatusCode == http.StatusNotFound {
mlog.Warn("permalinked post not found", mlog.String("referenced_post_id", permalinkPreviewedPost.PostID))
return nil
}
return appErr
}
var permalinkPreviewedChannel *model.Channel
permalinkPreviewedChannel, appErr = a.GetChannel(c, previewedPost.ChannelId)
if appErr != nil {
if appErr.StatusCode == http.StatusNotFound {
mlog.Warn("channel containing permalinked post not found", mlog.String("referenced_channel_id", previewedPost.ChannelId))
return nil
}
return appErr
}
// In case the user does have permission to read, we set the metadata back.
// Note that this is the return value to the post creator, and has nothing to do
// with the content of the websocket broadcast to that user or any other.
if a.HasPermissionToReadChannel(c, post.UserId, permalinkPreviewedChannel) {
post.Metadata.Embeds[0].Data = permalinkPreviewedPost
}
broadcastCopy := message.GetBroadcast()
broadcastCopy.ChannelHook = func(userID string, ev *model.WebSocketEvent) bool {
if a.HasPermissionToReadChannel(c, userID, permalinkPreviewedChannel) {
// If there is no change, then the original post which was attached
// (at the start of the method) will get sent.
return false
for _, cm := range channelMembers {
if permalinkPreviewedPost != nil {
post.Metadata.Embeds[0].Data = permalinkPreviewedPost
}
ev.AddWithCopy("post", postWithoutPermalinkPreviewJSON)
return true
postForUser := a.sanitizePostMetadataForUserAndChannel(c, post, permalinkPreviewedPost, permalinkPreviewedChannel, cm.UserId)
// Using DeepCopy here to avoid a race condition
// between publishing the event and setting the "post" data value below.
messageCopy := message.DeepCopy()
broadcastCopy := messageCopy.GetBroadcast()
broadcastCopy.UserId = cm.UserId
messageCopy.SetBroadcast(broadcastCopy)
postJSON, jsonErr := postForUser.ToJSON()
if jsonErr != nil {
mlog.Warn("Failed to encode post to JSON", mlog.Err(jsonErr))
}
messageCopy.Add("post", postJSON)
a.Publish(messageCopy)
}
message.SetBroadcast(broadcastCopy)
return nil
return true, nil
}
func (a *App) PatchPost(c *request.Context, postID string, patch *model.PostPatch) (*model.Post, *model.AppError) {

View File

@ -190,6 +190,18 @@ func (a *App) getEmbedsAndImages(c request.CTX, post *model.Post, isNewPost bool
return post
}
func (a *App) sanitizePostMetadataForUserAndChannel(c request.CTX, post *model.Post, previewedPost *model.PreviewPost, previewedChannel *model.Channel, userID string) *model.Post {
if post.Metadata == nil || len(post.Metadata.Embeds) == 0 || previewedPost == nil {
return post
}
if previewedChannel != nil && !a.HasPermissionToReadChannel(c, userID, previewedChannel) {
post.Metadata.Embeds[0].Data = nil
}
return post
}
func (a *App) SanitizePostMetadataForUser(c request.CTX, post *model.Post, userID string) (*model.Post, *model.AppError) {
if post.Metadata == nil || len(post.Metadata.Embeds) == 0 {
return post, nil

View File

@ -2765,6 +2765,58 @@ func TestContainsPermalink(t *testing.T) {
}
}
func TestSanitizePostMetadataForUserAndChannel(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.EnableLinkPreviews = true
*cfg.ServiceSettings.SiteURL = "http://mymattermost.com"
})
directChannel, err := th.App.createDirectChannel(th.Context, th.BasicUser.Id, th.BasicUser2.Id)
assert.Nil(t, err)
userID := model.NewId()
post := &model.Post{
Id: userID,
Metadata: &model.PostMetadata{
Embeds: []*model.PostEmbed{
{
Type: model.PostEmbedOpengraph,
URL: "ogURL",
Data: &opengraph.OpenGraph{
Images: []*ogimage.Image{
{
URL: "imageURL",
},
},
},
},
},
},
}
previewedPost := model.NewPreviewPost(post, th.BasicTeam, directChannel)
actual := th.App.sanitizePostMetadataForUserAndChannel(th.Context, post, previewedPost, directChannel, th.BasicUser2.Id)
assert.NotNil(t, actual.Metadata.Embeds[0].Data)
guestID := model.NewId()
guest := &model.User{
Email: "success+" + guestID + "@simulator.amazonses.com",
Username: "un_" + guestID,
Nickname: "nn_" + guestID,
Password: "Password1",
EmailVerified: true,
}
guest, appErr := th.App.CreateGuest(th.Context, guest)
require.Nil(t, appErr)
actual = th.App.sanitizePostMetadataForUserAndChannel(th.Context, post, previewedPost, directChannel, guest.Id)
assert.Nil(t, actual.Metadata.Embeds[0].Data)
}
func TestSanitizePostMetaDataForAudit(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()

View File

@ -92,15 +92,14 @@ type WebSocketMessage interface {
}
type WebsocketBroadcast struct {
ConnectionId string `json:"connection_id"` // broadcast only occurs for this connection
OmitConnectionId string `json:"omit_connection_id"` // broadcast is omitted for this connection
UserId string `json:"user_id"` // broadcast only occurs for this user
OmitUsers map[string]bool `json:"omit_users"` // broadcast is omitted for users listed here
ChannelId string `json:"channel_id"` // broadcast only occurs for users in this channel
ChannelHook func(userID string, ev *WebSocketEvent) bool `json:"-"` // ChannelHook is a function that runs for a channel scoped event. It can be used to modify the event payload based on some custom logic that runs only for connected users. The return value indicates whether the websocket event was modified or not.
TeamId string `json:"team_id"` // broadcast only occurs for users in this team
ContainsSanitizedData bool `json:"contains_sanitized_data,omitempty"` // broadcast only occurs for non-sysadmins
ContainsSensitiveData bool `json:"contains_sensitive_data,omitempty"` // broadcast only occurs for sysadmins
OmitUsers map[string]bool `json:"omit_users"` // broadcast is omitted for users listed here
UserId string `json:"user_id"` // broadcast only occurs for this user
ChannelId string `json:"channel_id"` // broadcast only occurs for users in this channel
TeamId string `json:"team_id"` // broadcast only occurs for users in this team
ConnectionId string `json:"connection_id"` // broadcast only occurs for this connection
OmitConnectionId string `json:"omit_connection_id"` // broadcast is omitted for this connection
ContainsSanitizedData bool `json:"contains_sanitized_data,omitempty"` // broadcast only occurs for non-sysadmins
ContainsSensitiveData bool `json:"contains_sensitive_data,omitempty"` // broadcast only occurs for sysadmins
// ReliableClusterSend indicates whether or not the message should
// be sent through the cluster using the reliable, TCP backed channel.
ReliableClusterSend bool `json:"-"`
@ -190,21 +189,10 @@ func (ev *WebSocketEvent) PrecomputeJSON() *WebSocketEvent {
return evCopy
}
func (ev *WebSocketEvent) RemovePrecomputedJSON() {
ev.precomputedJSON = nil
}
func (ev *WebSocketEvent) Add(key string, value any) {
ev.data[key] = value
}
// AddWithCopy copies the map and writes to a copy of that,
// and sets the map to the new event.
func (ev *WebSocketEvent) AddWithCopy(key string, value any) {
ev.data = copyMap(ev.data)
ev.data[key] = value
}
func NewWebSocketEvent(event, teamId, channelId, userId string, omitUsers map[string]bool, omitConnectionId string) *WebSocketEvent {
return &WebSocketEvent{
event: event,
@ -230,9 +218,17 @@ func (ev *WebSocketEvent) Copy() *WebSocketEvent {
}
func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
var dataCopy map[string]any
if ev.data != nil {
dataCopy = make(map[string]any, len(ev.data))
for k, v := range ev.data {
dataCopy[k] = v
}
}
evCopy := &WebSocketEvent{
event: ev.event,
data: copyMap(ev.data),
data: dataCopy,
broadcast: ev.broadcast.copy(),
sequence: ev.sequence,
precomputedJSON: ev.precomputedJSON.copy(),
@ -240,14 +236,6 @@ func (ev *WebSocketEvent) DeepCopy() *WebSocketEvent {
return evCopy
}
func copyMap[K comparable, V any](m map[K]V) map[K]V {
dataCopy := make(map[K]V, len(m))
for k, v := range m {
dataCopy[k] = v
}
return dataCopy
}
func (ev *WebSocketEvent) GetData() map[string]any {
return ev.data
}