MM-57873 Check user visibility when accepting channel invites for DMs (#27331)

* check user visibility when accepting channel invites for DMs

* stronger visibility checking for DM users

* check for correct remoteid for remote user in DM invite

* fix unit test

---------

Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Doug Lauder 2024-06-13 05:40:31 -04:00 committed by GitHub
parent 9733694854
commit bf8ddb4bdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 167 additions and 52 deletions

View File

@ -226,7 +226,7 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
if invite.Type == model.ChannelTypeDirect {
return scs.createDirectChannel(invite)
return scs.createDirectChannel(invite, rc)
}
channelNew := &model.Channel{
@ -250,14 +250,62 @@ func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.Rem
return channel, nil
}
func (scs *Service) createDirectChannel(invite channelInviteMsg) (*model.Channel, error) {
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
if len(invite.DirectParticipantIDs) != 2 {
return nil, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
}
channel, err := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), invite.DirectParticipantIDs[0], invite.DirectParticipantIDs[1], model.WithID(invite.ChannelId))
user1, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[0])
if err != nil {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, err)
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
}
user2, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[1])
if err != nil {
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
}
// determine the remote user
// - if both are remote then the DM channel does not belong on this server
// - if neither are remote then the DM channel should not be created via sync message
// - if only one is remote then we check visibility relative to that user
userRemote := user1
userLocal := user2
if !userRemote.IsRemote() {
userRemote = user2
userLocal = user1
}
if !userRemote.IsRemote() {
return nil, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
}
if userLocal.IsRemote() {
return nil, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
}
if userRemote.GetRemoteID() != rc.RemoteId {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
}
// ensure remote user is allowed to DM the local user
canSee, appErr := scs.app.UserCanSeeOtherUser(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id)
if appErr != nil {
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "cannot check user visibility for DM creation",
mlog.String("user_remote", userRemote.Id),
mlog.String("user_local", userLocal.Id),
mlog.String("channel_id", invite.ChannelId),
mlog.Err(appErr),
)
return nil, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
}
if !canSee {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
}
channel, appErr := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id, model.WithID(invite.ChannelId))
if appErr != nil {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
}
return channel, nil

View File

@ -4,6 +4,7 @@
package sharedchannel
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -24,6 +25,7 @@ var (
mockTypeChannel = mock.AnythingOfType("*model.Channel")
mockTypeString = mock.AnythingOfType("string")
mockTypeReqContext = mock.AnythingOfType("*request.Context")
mockTypeContext = mock.MatchedBy(func(ctx context.Context) bool { return true })
)
func TestOnReceiveChannelInvite(t *testing.T) {
@ -157,48 +159,78 @@ func TestOnReceiveChannelInvite(t *testing.T) {
assert.Equal(t, fmt.Sprintf("cannot make channel readonly `%s`: foo: bar, boom", invitation.ChannelId), err.Error())
})
t.Run("when invitation prescribes a direct channel, it does create a direct channel", func(t *testing.T) {
mockServer := &MockServerIface{}
logger := mlog.CreateConsoleTestLogger(t)
mockServer.On("Log").Return(logger)
mockApp := &MockAppIface{}
scs := &Service{
server: mockServer,
app: mockApp,
t.Run("DM channels", func(t *testing.T) {
var testRemoteID = model.NewId()
testCases := []struct {
desc string
user1 *model.User
user2 *model.User
canSee bool
expectSuccess bool
}{
{"valid users", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId()}, true, true},
{"swapped users", &model.User{Id: model.NewId()}, &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, true, true},
{"two remotes", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, true, false},
{"two locals", &model.User{Id: model.NewId()}, &model.User{Id: model.NewId()}, true, false},
{"can't see", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId()}, false, false},
{"invalid remoteid", &model.User{Id: model.NewId(), RemoteId: model.NewString("bogus")}, &model.User{Id: model.NewId()}, true, false},
}
mockStore := &mocks.Store{}
remoteCluster := &model.RemoteCluster{Name: "test3", CreatorId: model.NewId()}
invitation := channelInviteMsg{
ChannelId: model.NewId(),
TeamId: model.NewId(),
ReadOnly: false,
Type: model.ChannelTypeDirect,
DirectParticipantIDs: []string{model.NewId(), model.NewId()},
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
mockServer := &MockServerIface{}
logger := mlog.CreateConsoleTestLogger(t)
mockServer.On("Log").Return(logger)
mockApp := &MockAppIface{}
scs := &Service{
server: mockServer,
app: mockApp,
}
mockStore := &mocks.Store{}
remoteCluster := &model.RemoteCluster{Name: "test3", CreatorId: model.NewId(), RemoteId: testRemoteID}
invitation := channelInviteMsg{
ChannelId: model.NewId(),
TeamId: model.NewId(),
ReadOnly: false,
Type: model.ChannelTypeDirect,
DirectParticipantIDs: []string{tc.user1.Id, tc.user2.Id},
}
payload, err := json.Marshal(invitation)
require.NoError(t, err)
msg := model.RemoteClusterMsg{
Payload: payload,
}
mockChannelStore := mocks.ChannelStore{}
mockSharedChannelStore := mocks.SharedChannelStore{}
channel := &model.Channel{}
mockUserStore := mocks.UserStore{}
mockUserStore.On("Get", mockTypeContext, tc.user1.Id).
Return(tc.user1, nil)
mockUserStore.On("Get", mockTypeContext, tc.user2.Id).
Return(tc.user2, nil)
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("SharedChannel").Return(&mockSharedChannelStore)
mockStore.On("User").Return(&mockUserStore)
mockServer = scs.server.(*MockServerIface)
mockServer.On("GetStore").Return(mockStore)
mockApp.On("GetOrCreateDirectChannel", mockTypeReqContext, mockTypeString, mockTypeString, mock.AnythingOfType("model.ChannelOption")).
Return(channel, nil).Maybe()
mockApp.On("UserCanSeeOtherUser", mockTypeReqContext, mockTypeString, mockTypeString).Return(tc.canSee, nil).Maybe()
defer mockApp.AssertExpectations(t)
err = scs.onReceiveChannelInvite(msg, remoteCluster, nil)
require.Equal(t, tc.expectSuccess, err == nil)
})
}
payload, err := json.Marshal(invitation)
require.NoError(t, err)
msg := model.RemoteClusterMsg{
Payload: payload,
}
mockChannelStore := mocks.ChannelStore{}
mockSharedChannelStore := mocks.SharedChannelStore{}
channel := &model.Channel{}
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("SharedChannel").Return(&mockSharedChannelStore)
mockServer = scs.server.(*MockServerIface)
mockServer.On("GetStore").Return(mockStore)
mockApp.On("GetOrCreateDirectChannel", mock.AnythingOfType("*request.Context"), invitation.DirectParticipantIDs[0], invitation.DirectParticipantIDs[1], mock.AnythingOfType("model.ChannelOption")).Return(channel, nil)
defer mockApp.AssertExpectations(t)
err = scs.onReceiveChannelInvite(msg, remoteCluster, nil)
require.NoError(t, err)
})
}

View File

@ -558,6 +558,36 @@ func (_m *MockAppIface) UpdatePost(c request.CTX, post *model.Post, safeUpdate b
return r0, r1
}
// UserCanSeeOtherUser provides a mock function with given fields: c, userID, otherUserId
func (_m *MockAppIface) UserCanSeeOtherUser(c request.CTX, userID string, otherUserId string) (bool, *model.AppError) {
ret := _m.Called(c, userID, otherUserId)
if len(ret) == 0 {
panic("no return value specified for UserCanSeeOtherUser")
}
var r0 bool
var r1 *model.AppError
if rf, ok := ret.Get(0).(func(request.CTX, string, string) (bool, *model.AppError)); ok {
return rf(c, userID, otherUserId)
}
if rf, ok := ret.Get(0).(func(request.CTX, string, string) bool); ok {
r0 = rf(c, userID, otherUserId)
} else {
r0 = ret.Get(0).(bool)
}
if rf, ok := ret.Get(1).(func(request.CTX, string, string) *model.AppError); ok {
r1 = rf(c, userID, otherUserId)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// NewMockAppIface creates a new instance of MockAppIface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAppIface(t interface {

View File

@ -55,6 +55,7 @@ type AppIface interface {
SendEphemeralPost(c request.CTX, userId string, post *model.Post) *model.Post
CreateChannelWithUser(c request.CTX, channel *model.Channel, userId string) (*model.Channel, *model.AppError)
GetOrCreateDirectChannel(c request.CTX, userId, otherUserId string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError)
UserCanSeeOtherUser(c request.CTX, userID string, otherUserId string) (bool, *model.AppError)
AddUserToChannel(c request.CTX, user *model.User, channel *model.Channel, skipTeamMemberIntegrityCheck bool) (*model.ChannelMember, *model.AppError)
AddUserToTeamByTeamId(c request.CTX, teamId string, user *model.User) *model.AppError
PermanentDeleteChannel(c request.CTX, channel *model.Channel) *model.AppError

View File

@ -20,6 +20,7 @@ import (
var (
ErrRemoteIDMismatch = errors.New("remoteID mismatch")
ErrChannelIDMismatch = errors.New("channelID mismatch")
ErrUserDMPermission = errors.New("users cannot DM each other")
)
func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.RemoteCluster, response *remotecluster.Response) error {
@ -224,20 +225,23 @@ func (scs *Service) upsertSyncUser(c request.CTX, user *model.User, channel *mod
}
}
// Add user to team. We do this here regardless of whether the user was
// Add user to team and channel. We do this here regardless of whether the user was
// just created or patched since there are three steps to adding a user
// (insert rec, add to team, add to channel) and any one could fail.
// Instead of undoing what succeeded on any failure we simply do all steps each
// time. AddUserToChannel & AddUserToTeamByTeamId do not error if user was already
// added and exit quickly.
if err := scs.app.AddUserToTeamByTeamId(request.EmptyContext(scs.server.Log()), channel.TeamId, userSaved); err != nil {
return nil, fmt.Errorf("error adding sync user to Team: %w", err)
// added and exit quickly. Not needed for DMs where teamId is empty.
if channel.TeamId != "" {
// add user to team
if err := scs.app.AddUserToTeamByTeamId(request.EmptyContext(scs.server.Log()), channel.TeamId, userSaved); err != nil {
return nil, fmt.Errorf("error adding sync user to Team: %w", err)
}
// add user to channel
if _, err := scs.app.AddUserToChannel(c, userSaved, channel, false); err != nil {
return nil, fmt.Errorf("error adding sync user to ChannelMembers: %w", err)
}
}
// add user to channel
if _, err := scs.app.AddUserToChannel(c, userSaved, channel, false); err != nil {
return nil, fmt.Errorf("error adding sync user to ChannelMembers: %w", err)
}
return userSaved, nil
}