MM-58577 Check remote ownership for posts and reactions (#27317)

* - ensure that posts and reactions can only be added via sync when coming from a remote that the target channel is shared with.
- ensure that posts and reactions are only modified/deleted by the remote that owns them.

* check that reaction belongs to post that belongs to channel that is shared with remote;  check that posts belong to channel shared with remote

* check for correct error type in unit test

* tweak unit test
This commit is contained in:
Doug Lauder 2024-06-11 11:51:00 -04:00 committed by GitHub
parent 6f8de3449a
commit 594ba6e665
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 276 additions and 26 deletions

View File

@ -7684,6 +7684,24 @@ func (s *OpenTracingLayerReactionStore) GetForPostSince(postId string, since int
return result, err
}
func (s *OpenTracingLayerReactionStore) GetSingle(userID string, postID string, remoteID string, emojiName string) (*model.Reaction, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ReactionStore.GetSingle")
s.Root.Store.SetContext(newCtx)
defer func() {
s.Root.Store.SetContext(origCtx)
}()
defer span.Finish()
result, err := s.ReactionStore.GetSingle(userID, postID, remoteID, emojiName)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
return result, err
}
func (s *OpenTracingLayerReactionStore) GetUniqueCountForPost(postId string) (int, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ReactionStore.GetUniqueCountForPost")

View File

@ -8741,6 +8741,27 @@ func (s *RetryLayerReactionStore) GetForPostSince(postId string, since int64, ex
}
func (s *RetryLayerReactionStore) GetSingle(userID string, postID string, remoteID string, emojiName string) (*model.Reaction, error) {
tries := 0
for {
result, err := s.ReactionStore.GetSingle(userID, postID, remoteID, emojiName)
if err == nil {
return result, nil
}
if !isRepeatableError(err) {
return result, err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
return result, err
}
timepkg.Sleep(100 * timepkg.Millisecond)
}
}
func (s *RetryLayerReactionStore) GetUniqueCountForPost(postId string) (int, error) {
tries := 0

View File

@ -5,6 +5,7 @@ package sqlstore
import (
"database/sql"
"fmt"
"time"
sq "github.com/mattermost/squirrel"
@ -198,6 +199,33 @@ func (s *SqlReactionStore) BulkGetForPosts(postIds []string) ([]*model.Reaction,
return reactions, nil
}
func (s *SqlReactionStore) GetSingle(userID, postID, remoteID, emojiName string) (*model.Reaction, error) {
query := s.getQueryBuilder().
Select("UserId", "PostId", "EmojiName", "CreateAt",
"COALESCE(UpdateAt, CreateAt) As UpdateAt", "COALESCE(DeleteAt, 0) As DeleteAt",
"RemoteId", "ChannelId").
From("Reactions").
Where(sq.Eq{"UserId": userID}).
Where(sq.Eq{"PostId": postID}).
Where(sq.Eq{"COALESCE(RemoteId, '')": remoteID}).
Where(sq.Eq{"EmojiName": emojiName})
queryString, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "reactions_getsingle_tosql")
}
var reactions []*model.Reaction
if err := s.GetReplicaX().Select(&reactions, queryString, args...); err != nil {
return nil, errors.Wrapf(err, "failed to find reaction")
}
if len(reactions) == 0 {
return nil, store.NewErrNotFound("Reaction", fmt.Sprintf("user_id=%s, post_id=%s, remote_id=%s, emoji_name=%s",
userID, postID, remoteID, emojiName))
}
return reactions[0], nil
}
func (s *SqlReactionStore) DeleteAllWithEmojiName(emojiName string) error {
var reactions []*model.Reaction
now := model.GetMillis()

View File

@ -740,6 +740,7 @@ type ReactionStore interface {
ExistsOnPost(postId string, emojiName string) (bool, error)
DeleteAllWithEmojiName(emojiName string) error
BulkGetForPosts(postIds []string) ([]*model.Reaction, error)
GetSingle(userID, postID, remoteID, emojiName string) (*model.Reaction, error)
DeleteOrphanedRowsByIds(r *model.RetentionIdsForDeletion) error
PermanentDeleteBatch(endTime int64, limit int64) (int64, error)
PermanentDeleteByUser(userID string) error

View File

@ -198,6 +198,36 @@ func (_m *ReactionStore) GetForPostSince(postId string, since int64, excludeRemo
return r0, r1
}
// GetSingle provides a mock function with given fields: userID, postID, remoteID, emojiName
func (_m *ReactionStore) GetSingle(userID string, postID string, remoteID string, emojiName string) (*model.Reaction, error) {
ret := _m.Called(userID, postID, remoteID, emojiName)
if len(ret) == 0 {
panic("no return value specified for GetSingle")
}
var r0 *model.Reaction
var r1 error
if rf, ok := ret.Get(0).(func(string, string, string, string) (*model.Reaction, error)); ok {
return rf(userID, postID, remoteID, emojiName)
}
if rf, ok := ret.Get(0).(func(string, string, string, string) *model.Reaction); ok {
r0 = rf(userID, postID, remoteID, emojiName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Reaction)
}
}
if rf, ok := ret.Get(1).(func(string, string, string, string) error); ok {
r1 = rf(userID, postID, remoteID, emojiName)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetUniqueCountForPost provides a mock function with given fields: postId
func (_m *ReactionStore) GetUniqueCountForPost(postId string) (int, error) {
ret := _m.Called(postId)

View File

@ -31,6 +31,7 @@ func TestReactionStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStor
t.Run("ReactionDeadlock", func(t *testing.T) { testReactionDeadlock(t, rctx, ss) })
t.Run("ExistsOnPost", func(t *testing.T) { testExistsOnPost(t, rctx, ss) })
t.Run("GetUniqueCountForPost", func(t *testing.T) { testGetUniqueCountForPost(t, rctx, ss) })
t.Run("ReactionGetSingle", func(t *testing.T) { testReactionGetSingle(t, rctx, ss) })
}
func testReactionSave(t *testing.T, rctx request.CTX, ss store.Store) {
@ -938,3 +939,85 @@ func testGetUniqueCountForPost(t *testing.T, rctx request.CTX, ss store.Store) {
require.NoError(t, err)
require.Equal(t, 2, count)
}
func testReactionGetSingle(t *testing.T, rctx request.CTX, ss store.Store) {
var (
testUserID = model.NewId()
testEmojiName = "smile"
testRemoteID = model.NewId()
)
t.Run("get without remoteId", func(t *testing.T) {
post, err := ss.Post().Save(rctx, &model.Post{
ChannelId: model.NewId(),
UserId: testUserID,
})
require.NoError(t, err)
reaction := &model.Reaction{
UserId: testUserID,
PostId: post.Id,
EmojiName: testEmojiName,
}
_, nErr := ss.Reaction().Save(reaction)
require.NoError(t, nErr)
reactionFound, err := ss.Reaction().GetSingle(testUserID, post.Id, "", testEmojiName)
require.NoError(t, err)
assert.Equal(t, testUserID, reactionFound.UserId)
assert.Equal(t, post.Id, reactionFound.PostId)
assert.Equal(t, "", reactionFound.GetRemoteID())
assert.Equal(t, testEmojiName, reactionFound.EmojiName)
})
t.Run("get with remoteId", func(t *testing.T) {
post, err := ss.Post().Save(rctx, &model.Post{
ChannelId: model.NewId(),
UserId: testUserID,
})
require.NoError(t, err)
reaction := &model.Reaction{
UserId: testUserID,
PostId: post.Id,
EmojiName: testEmojiName,
RemoteId: model.NewString(testRemoteID),
}
_, nErr := ss.Reaction().Save(reaction)
require.NoError(t, nErr)
reactionFound, err := ss.Reaction().GetSingle(testUserID, post.Id, testRemoteID, testEmojiName)
require.NoError(t, err)
assert.Equal(t, testUserID, reactionFound.UserId)
assert.Equal(t, post.Id, reactionFound.PostId)
assert.Equal(t, testRemoteID, reactionFound.GetRemoteID())
assert.Equal(t, testEmojiName, reactionFound.EmojiName)
})
t.Run("not found - wrong remoteID", func(t *testing.T) {
post, err := ss.Post().Save(rctx, &model.Post{
ChannelId: model.NewId(),
UserId: testUserID,
})
require.NoError(t, err)
reaction := &model.Reaction{
UserId: testUserID,
PostId: post.Id,
EmojiName: testEmojiName,
RemoteId: model.NewString(testRemoteID),
}
_, nErr := ss.Reaction().Save(reaction)
require.NoError(t, nErr)
reactionFound, err := ss.Reaction().GetSingle(testUserID, post.Id, "bogus-remoteId", testEmojiName)
require.Error(t, err)
assert.Nil(t, reactionFound)
var errNotFound *store.ErrNotFound
assert.ErrorAs(t, err, &errNotFound)
})
}

View File

@ -6937,6 +6937,22 @@ func (s *TimerLayerReactionStore) GetForPostSince(postId string, since int64, ex
return result, err
}
func (s *TimerLayerReactionStore) GetSingle(userID string, postID string, remoteID string, emojiName string) (*model.Reaction, error) {
start := time.Now()
result, err := s.ReactionStore.GetSingle(userID, postID, remoteID, emojiName)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
success := "false"
if err == nil {
success = "true"
}
s.Root.Metrics.ObserveStoreMethodDuration("ReactionStore.GetSingle", success, elapsed)
}
return result, err
}
func (s *TimerLayerReactionStore) GetUniqueCountForPost(postId string) (int, error) {
start := time.Now()

View File

@ -18,7 +18,8 @@ import (
)
var (
ErrRemoteIDMismatch = errors.New("remoteID mismatch")
ErrRemoteIDMismatch = errors.New("remoteID mismatch")
ErrChannelIDMismatch = errors.New("channelID mismatch")
)
func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.RemoteCluster, response *remotecluster.Response) error {
@ -46,7 +47,7 @@ func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.R
}
func (scs *Service) processSyncMessage(c request.CTX, syncMsg *model.SyncMsg, rc *model.RemoteCluster, response *remotecluster.Response) error {
var channel *model.Channel
var targetChannel *model.Channel
var team *model.Team
var err error
@ -65,14 +66,23 @@ func (scs *Service) processSyncMessage(c request.CTX, syncMsg *model.SyncMsg, rc
mlog.Int("reaction_count", len(syncMsg.Reactions)),
)
if channel, err = scs.server.GetStore().Channel().Get(syncMsg.ChannelId, true); err != nil {
if targetChannel, err = scs.server.GetStore().Channel().Get(syncMsg.ChannelId, true); err != nil {
// if the channel doesn't exist then none of these sync items are going to work.
return fmt.Errorf("channel not found processing sync message: %w", err)
}
// make sure target channel is shared with the remote
exists, err := scs.server.GetStore().SharedChannel().HasRemote(targetChannel.Id, rc.RemoteId)
if err != nil {
return fmt.Errorf("cannot check channel share state for sync message: %w", err)
}
if !exists {
return fmt.Errorf("cannot process sync message; channel not shared with remote: %w", ErrRemoteIDMismatch)
}
// add/update users before posts
for _, user := range syncMsg.Users {
if userSaved, err := scs.upsertSyncUser(c, user, channel, rc); err != nil {
if userSaved, err := scs.upsertSyncUser(c, user, targetChannel, rc); err != nil {
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync user",
mlog.String("remote", rc.Name),
mlog.String("channel_id", syncMsg.ChannelId),
@ -103,7 +113,7 @@ func (scs *Service) processSyncMessage(c request.CTX, syncMsg *model.SyncMsg, rc
continue
}
if channel.Type != model.ChannelTypeDirect && team == nil {
if targetChannel.Type != model.ChannelTypeDirect && team == nil {
var err2 error
team, err2 = scs.server.GetStore().Channel().GetTeamForChannel(syncMsg.ChannelId)
if err2 != nil {
@ -124,7 +134,7 @@ func (scs *Service) processSyncMessage(c request.CTX, syncMsg *model.SyncMsg, rc
}
// add/update post
rpost, err := scs.upsertSyncPost(post, channel, rc)
rpost, err := scs.upsertSyncPost(post, targetChannel, rc)
if err != nil {
syncResp.PostErrors = append(syncResp.PostErrors, post.Id)
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync post",
@ -140,7 +150,7 @@ func (scs *Service) processSyncMessage(c request.CTX, syncMsg *model.SyncMsg, rc
// add/remove reactions
for _, reaction := range syncMsg.Reactions {
if _, err := scs.upsertSyncReaction(reaction, rc); err != nil {
if _, err := scs.upsertSyncReaction(reaction, targetChannel, rc); err != nil {
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Error upserting sync reaction",
mlog.String("remote", rc.Name),
mlog.String("user_id", reaction.UserId),
@ -189,12 +199,12 @@ func (scs *Service) upsertSyncUser(c request.CTX, user *model.User, channel *mod
}
} else {
// existing user. Make sure user belongs to the remote that issued the update
if SafeString(euser.RemoteId) != rc.RemoteId {
if euser.GetRemoteID() != rc.RemoteId {
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "RemoteID mismatch sync'ing user",
mlog.String("remote", rc.Name),
mlog.String("user_id", user.Id),
mlog.String("existing_user_remote_id", SafeString(euser.RemoteId)),
mlog.String("update_user_remote_id", SafeString(user.RemoteId)),
mlog.String("existing_user_remote_id", euser.GetRemoteID()),
mlog.String("update_user_remote_id", user.GetRemoteID()),
)
return nil, fmt.Errorf("error updating user: %w", ErrRemoteIDMismatch)
}
@ -337,7 +347,7 @@ func (scs *Service) updateSyncUser(rctx request.CTX, patch *model.UserPatch, use
return nil, fmt.Errorf("error updating sync user %s: %w", user.Id, err)
}
func (scs *Service) upsertSyncPost(post *model.Post, channel *model.Channel, rc *model.RemoteCluster) (*model.Post, error) {
func (scs *Service) upsertSyncPost(post *model.Post, targetChannel *model.Channel, rc *model.RemoteCluster) (*model.Post, error) {
var appErr *model.AppError
post.RemoteId = model.NewString(rc.RemoteId)
@ -350,9 +360,24 @@ func (scs *Service) upsertSyncPost(post *model.Post, channel *model.Channel, rc
}
}
// ensure the post is in the target channel. This ensures the post can only be associated with a channel
// that is shared with the remote.
if post.ChannelId != targetChannel.Id || (rpost != nil && rpost.ChannelId != targetChannel.Id) {
return nil, fmt.Errorf("post sync failed: %w", ErrChannelIDMismatch)
}
if rpost == nil {
// post doesn't exist; create new one
rpost, appErr = scs.app.CreatePost(rctx, post, channel, true, true)
// post doesn't exist; check that user belongs to remote and create post.
// user is not checked for edit/delete because admins can perform those actions
user, err := scs.server.GetStore().User().Get(context.TODO(), post.UserId)
if err != nil {
return nil, fmt.Errorf("error fetching user for post sync: %w", err)
}
if user.GetRemoteID() != rc.RemoteId {
return nil, fmt.Errorf("post sync failed: %w", ErrRemoteIDMismatch)
}
rpost, appErr = scs.app.CreatePost(rctx, post, targetChannel, true, true)
if appErr == nil {
scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Created sync post",
mlog.String("post_id", post.Id),
@ -392,21 +417,49 @@ func (scs *Service) upsertSyncPost(post *model.Post, channel *model.Channel, rc
return rpost, rerr
}
func (scs *Service) upsertSyncReaction(reaction *model.Reaction, rc *model.RemoteCluster) (*model.Reaction, error) {
func (scs *Service) upsertSyncReaction(reaction *model.Reaction, targetChannel *model.Channel, rc *model.RemoteCluster) (*model.Reaction, error) {
savedReaction := reaction
var appErr *model.AppError
reaction.RemoteId = model.NewString(rc.RemoteId)
// check that the reaction's post is in the target channel. This ensures the reaction can only be associated with a post
// that is in a channel shared with the remote.
rctx := request.EmptyContext(scs.server.Log())
post, err := scs.server.GetStore().Post().GetSingle(rctx, reaction.PostId, true)
if err != nil {
return nil, fmt.Errorf("error fetching post for reaction sync: %w", err)
}
if post.ChannelId != targetChannel.Id {
return nil, fmt.Errorf("reaction sync failed: %w", ErrChannelIDMismatch)
}
if reaction.DeleteAt == 0 {
existingReaction, err := scs.server.GetStore().Reaction().GetSingle(reaction.UserId, reaction.PostId, rc.RemoteId, reaction.EmojiName)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("error fetching reaction for sync: %w", err)
}
if existingReaction == nil {
// reaction does not exist; check that user belongs to remote and create reaction
// this is not done for delete since deletion can be done by admins on the remote
user, err := scs.server.GetStore().User().Get(context.TODO(), reaction.UserId)
if err != nil {
return nil, fmt.Errorf("error fetching user for reaction sync: %w", err)
}
if user.GetRemoteID() != rc.RemoteId {
return nil, fmt.Errorf("reaction sync failed: %w", ErrRemoteIDMismatch)
}
reaction.RemoteId = model.NewString(rc.RemoteId)
savedReaction, appErr = scs.app.SaveReactionForPost(request.EmptyContext(scs.server.Log()), reaction)
} else {
// make sure the reaction being deleted is owned by the remote
if existingReaction.GetRemoteID() != rc.RemoteId {
return nil, fmt.Errorf("reaction sync failed: %w", ErrRemoteIDMismatch)
}
appErr = scs.app.DeleteReactionForPost(request.EmptyContext(scs.server.Log()), reaction)
}
var err error
var retErr error
if appErr != nil {
err = errors.New(appErr.Error())
retErr = errors.New(appErr.Error())
}
return savedReaction, err
return savedReaction, retErr
}

View File

@ -148,10 +148,3 @@ func reducePostsSliceInCache(posts []*model.Post, cache map[string]*model.Post)
}
return reduced
}
func SafeString(p *string) string {
if p == nil {
return ""
}
return *p
}

View File

@ -64,3 +64,10 @@ func (o *Reaction) PreUpdate() {
o.RemoteId = NewString("")
}
}
func (o *Reaction) GetRemoteID() string {
if o.RemoteId == nil {
return ""
}
return *o.RemoteId
}