mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-57873 Check user visibility when accepting channel invites for DMs (#27331)
* check user visibility when accepting channel invites for DMs * stronger visibility checking for DM users * check for correct remoteid for remote user in DM invite * fix unit test --------- Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
parent
9733694854
commit
bf8ddb4bdc
@ -226,7 +226,7 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
|
||||
|
||||
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
|
||||
if invite.Type == model.ChannelTypeDirect {
|
||||
return scs.createDirectChannel(invite)
|
||||
return scs.createDirectChannel(invite, rc)
|
||||
}
|
||||
|
||||
channelNew := &model.Channel{
|
||||
@ -250,14 +250,62 @@ func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.Rem
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func (scs *Service) createDirectChannel(invite channelInviteMsg) (*model.Channel, error) {
|
||||
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
|
||||
if len(invite.DirectParticipantIDs) != 2 {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
|
||||
}
|
||||
|
||||
channel, err := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), invite.DirectParticipantIDs[0], invite.DirectParticipantIDs[1], model.WithID(invite.ChannelId))
|
||||
user1, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, err)
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
|
||||
}
|
||||
|
||||
user2, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
|
||||
}
|
||||
|
||||
// determine the remote user
|
||||
// - if both are remote then the DM channel does not belong on this server
|
||||
// - if neither are remote then the DM channel should not be created via sync message
|
||||
// - if only one is remote then we check visibility relative to that user
|
||||
userRemote := user1
|
||||
userLocal := user2
|
||||
if !userRemote.IsRemote() {
|
||||
userRemote = user2
|
||||
userLocal = user1
|
||||
}
|
||||
|
||||
if !userRemote.IsRemote() {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
|
||||
}
|
||||
|
||||
if userLocal.IsRemote() {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
|
||||
}
|
||||
|
||||
if userRemote.GetRemoteID() != rc.RemoteId {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
|
||||
}
|
||||
|
||||
// ensure remote user is allowed to DM the local user
|
||||
canSee, appErr := scs.app.UserCanSeeOtherUser(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id)
|
||||
if appErr != nil {
|
||||
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "cannot check user visibility for DM creation",
|
||||
mlog.String("user_remote", userRemote.Id),
|
||||
mlog.String("user_local", userLocal.Id),
|
||||
mlog.String("channel_id", invite.ChannelId),
|
||||
mlog.Err(appErr),
|
||||
)
|
||||
return nil, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
|
||||
}
|
||||
if !canSee {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
|
||||
}
|
||||
|
||||
channel, appErr := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id, model.WithID(invite.ChannelId))
|
||||
if appErr != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
|
@ -4,6 +4,7 @@
|
||||
package sharedchannel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -24,6 +25,7 @@ var (
|
||||
mockTypeChannel = mock.AnythingOfType("*model.Channel")
|
||||
mockTypeString = mock.AnythingOfType("string")
|
||||
mockTypeReqContext = mock.AnythingOfType("*request.Context")
|
||||
mockTypeContext = mock.MatchedBy(func(ctx context.Context) bool { return true })
|
||||
)
|
||||
|
||||
func TestOnReceiveChannelInvite(t *testing.T) {
|
||||
@ -157,48 +159,78 @@ func TestOnReceiveChannelInvite(t *testing.T) {
|
||||
assert.Equal(t, fmt.Sprintf("cannot make channel readonly `%s`: foo: bar, boom", invitation.ChannelId), err.Error())
|
||||
})
|
||||
|
||||
t.Run("when invitation prescribes a direct channel, it does create a direct channel", func(t *testing.T) {
|
||||
mockServer := &MockServerIface{}
|
||||
logger := mlog.CreateConsoleTestLogger(t)
|
||||
mockServer.On("Log").Return(logger)
|
||||
mockApp := &MockAppIface{}
|
||||
scs := &Service{
|
||||
server: mockServer,
|
||||
app: mockApp,
|
||||
t.Run("DM channels", func(t *testing.T) {
|
||||
var testRemoteID = model.NewId()
|
||||
testCases := []struct {
|
||||
desc string
|
||||
user1 *model.User
|
||||
user2 *model.User
|
||||
canSee bool
|
||||
expectSuccess bool
|
||||
}{
|
||||
{"valid users", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId()}, true, true},
|
||||
{"swapped users", &model.User{Id: model.NewId()}, &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, true, true},
|
||||
{"two remotes", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, true, false},
|
||||
{"two locals", &model.User{Id: model.NewId()}, &model.User{Id: model.NewId()}, true, false},
|
||||
{"can't see", &model.User{Id: model.NewId(), RemoteId: &testRemoteID}, &model.User{Id: model.NewId()}, false, false},
|
||||
{"invalid remoteid", &model.User{Id: model.NewId(), RemoteId: model.NewString("bogus")}, &model.User{Id: model.NewId()}, true, false},
|
||||
}
|
||||
|
||||
mockStore := &mocks.Store{}
|
||||
remoteCluster := &model.RemoteCluster{Name: "test3", CreatorId: model.NewId()}
|
||||
invitation := channelInviteMsg{
|
||||
ChannelId: model.NewId(),
|
||||
TeamId: model.NewId(),
|
||||
ReadOnly: false,
|
||||
Type: model.ChannelTypeDirect,
|
||||
DirectParticipantIDs: []string{model.NewId(), model.NewId()},
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
mockServer := &MockServerIface{}
|
||||
logger := mlog.CreateConsoleTestLogger(t)
|
||||
mockServer.On("Log").Return(logger)
|
||||
mockApp := &MockAppIface{}
|
||||
scs := &Service{
|
||||
server: mockServer,
|
||||
app: mockApp,
|
||||
}
|
||||
|
||||
mockStore := &mocks.Store{}
|
||||
remoteCluster := &model.RemoteCluster{Name: "test3", CreatorId: model.NewId(), RemoteId: testRemoteID}
|
||||
invitation := channelInviteMsg{
|
||||
ChannelId: model.NewId(),
|
||||
TeamId: model.NewId(),
|
||||
ReadOnly: false,
|
||||
Type: model.ChannelTypeDirect,
|
||||
DirectParticipantIDs: []string{tc.user1.Id, tc.user2.Id},
|
||||
}
|
||||
payload, err := json.Marshal(invitation)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := model.RemoteClusterMsg{
|
||||
Payload: payload,
|
||||
}
|
||||
mockChannelStore := mocks.ChannelStore{}
|
||||
mockSharedChannelStore := mocks.SharedChannelStore{}
|
||||
channel := &model.Channel{}
|
||||
|
||||
mockUserStore := mocks.UserStore{}
|
||||
mockUserStore.On("Get", mockTypeContext, tc.user1.Id).
|
||||
Return(tc.user1, nil)
|
||||
mockUserStore.On("Get", mockTypeContext, tc.user2.Id).
|
||||
Return(tc.user2, nil)
|
||||
|
||||
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
|
||||
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
|
||||
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
mockStore.On("SharedChannel").Return(&mockSharedChannelStore)
|
||||
mockStore.On("User").Return(&mockUserStore)
|
||||
|
||||
mockServer = scs.server.(*MockServerIface)
|
||||
mockServer.On("GetStore").Return(mockStore)
|
||||
|
||||
mockApp.On("GetOrCreateDirectChannel", mockTypeReqContext, mockTypeString, mockTypeString, mock.AnythingOfType("model.ChannelOption")).
|
||||
Return(channel, nil).Maybe()
|
||||
mockApp.On("UserCanSeeOtherUser", mockTypeReqContext, mockTypeString, mockTypeString).Return(tc.canSee, nil).Maybe()
|
||||
|
||||
defer mockApp.AssertExpectations(t)
|
||||
|
||||
err = scs.onReceiveChannelInvite(msg, remoteCluster, nil)
|
||||
require.Equal(t, tc.expectSuccess, err == nil)
|
||||
})
|
||||
}
|
||||
payload, err := json.Marshal(invitation)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := model.RemoteClusterMsg{
|
||||
Payload: payload,
|
||||
}
|
||||
mockChannelStore := mocks.ChannelStore{}
|
||||
mockSharedChannelStore := mocks.SharedChannelStore{}
|
||||
channel := &model.Channel{}
|
||||
|
||||
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
|
||||
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
|
||||
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
mockStore.On("SharedChannel").Return(&mockSharedChannelStore)
|
||||
|
||||
mockServer = scs.server.(*MockServerIface)
|
||||
mockServer.On("GetStore").Return(mockStore)
|
||||
|
||||
mockApp.On("GetOrCreateDirectChannel", mock.AnythingOfType("*request.Context"), invitation.DirectParticipantIDs[0], invitation.DirectParticipantIDs[1], mock.AnythingOfType("model.ChannelOption")).Return(channel, nil)
|
||||
defer mockApp.AssertExpectations(t)
|
||||
|
||||
err = scs.onReceiveChannelInvite(msg, remoteCluster, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -558,6 +558,36 @@ func (_m *MockAppIface) UpdatePost(c request.CTX, post *model.Post, safeUpdate b
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// UserCanSeeOtherUser provides a mock function with given fields: c, userID, otherUserId
|
||||
func (_m *MockAppIface) UserCanSeeOtherUser(c request.CTX, userID string, otherUserId string) (bool, *model.AppError) {
|
||||
ret := _m.Called(c, userID, otherUserId)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UserCanSeeOtherUser")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
var r1 *model.AppError
|
||||
if rf, ok := ret.Get(0).(func(request.CTX, string, string) (bool, *model.AppError)); ok {
|
||||
return rf(c, userID, otherUserId)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(request.CTX, string, string) bool); ok {
|
||||
r0 = rf(c, userID, otherUserId)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(request.CTX, string, string) *model.AppError); ok {
|
||||
r1 = rf(c, userID, otherUserId)
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).(*model.AppError)
|
||||
}
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewMockAppIface creates a new instance of MockAppIface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockAppIface(t interface {
|
||||
|
@ -55,6 +55,7 @@ type AppIface interface {
|
||||
SendEphemeralPost(c request.CTX, userId string, post *model.Post) *model.Post
|
||||
CreateChannelWithUser(c request.CTX, channel *model.Channel, userId string) (*model.Channel, *model.AppError)
|
||||
GetOrCreateDirectChannel(c request.CTX, userId, otherUserId string, channelOptions ...model.ChannelOption) (*model.Channel, *model.AppError)
|
||||
UserCanSeeOtherUser(c request.CTX, userID string, otherUserId string) (bool, *model.AppError)
|
||||
AddUserToChannel(c request.CTX, user *model.User, channel *model.Channel, skipTeamMemberIntegrityCheck bool) (*model.ChannelMember, *model.AppError)
|
||||
AddUserToTeamByTeamId(c request.CTX, teamId string, user *model.User) *model.AppError
|
||||
PermanentDeleteChannel(c request.CTX, channel *model.Channel) *model.AppError
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
var (
|
||||
ErrRemoteIDMismatch = errors.New("remoteID mismatch")
|
||||
ErrChannelIDMismatch = errors.New("channelID mismatch")
|
||||
ErrUserDMPermission = errors.New("users cannot DM each other")
|
||||
)
|
||||
|
||||
func (scs *Service) onReceiveSyncMessage(msg model.RemoteClusterMsg, rc *model.RemoteCluster, response *remotecluster.Response) error {
|
||||
@ -224,20 +225,23 @@ func (scs *Service) upsertSyncUser(c request.CTX, user *model.User, channel *mod
|
||||
}
|
||||
}
|
||||
|
||||
// Add user to team. We do this here regardless of whether the user was
|
||||
// Add user to team and channel. We do this here regardless of whether the user was
|
||||
// just created or patched since there are three steps to adding a user
|
||||
// (insert rec, add to team, add to channel) and any one could fail.
|
||||
// Instead of undoing what succeeded on any failure we simply do all steps each
|
||||
// time. AddUserToChannel & AddUserToTeamByTeamId do not error if user was already
|
||||
// added and exit quickly.
|
||||
if err := scs.app.AddUserToTeamByTeamId(request.EmptyContext(scs.server.Log()), channel.TeamId, userSaved); err != nil {
|
||||
return nil, fmt.Errorf("error adding sync user to Team: %w", err)
|
||||
// added and exit quickly. Not needed for DMs where teamId is empty.
|
||||
if channel.TeamId != "" {
|
||||
// add user to team
|
||||
if err := scs.app.AddUserToTeamByTeamId(request.EmptyContext(scs.server.Log()), channel.TeamId, userSaved); err != nil {
|
||||
return nil, fmt.Errorf("error adding sync user to Team: %w", err)
|
||||
}
|
||||
// add user to channel
|
||||
if _, err := scs.app.AddUserToChannel(c, userSaved, channel, false); err != nil {
|
||||
return nil, fmt.Errorf("error adding sync user to ChannelMembers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// add user to channel
|
||||
if _, err := scs.app.AddUserToChannel(c, userSaved, channel, false); err != nil {
|
||||
return nil, fmt.Errorf("error adding sync user to ChannelMembers: %w", err)
|
||||
}
|
||||
return userSaved, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user