MM-57867 Don't delete existing DM on invitation error (#27357)

* ensure channel invitations create new channels; don't delete pre-existing channels on failure cleanup

* update comment
This commit is contained in:
Doug Lauder 2024-06-19 10:20:46 -04:00 committed by GitHub
parent f41e8ad756
commit 8181a9ddff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 19 deletions

View File

@ -164,6 +164,7 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
// check if channel already exists
var channel *model.Channel
var created bool
_, err := scs.server.GetStore().Channel().Get(invite.ChannelId, true)
if err == nil {
// the channel already exists on this server; could be the remote is trying to re-share it (not allowed at this time).
@ -172,10 +173,29 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
}
// create new local channel to sync with the remote channel
if channel, err = scs.handleChannelCreation(invite, rc); err != nil {
if channel, created, err = scs.handleChannelCreation(invite, rc); err != nil {
return err
}
// sanity check to ensure the channel returned has the expected id. Otherwise sync will not work as expected and will fail
// silently.
if invite.ChannelId != channel.Id {
// as of this writing, this scenario should only be possible if the invite included a DM channel invitation with a
// combination of two user ids (one remote, one local) that already have a DM on this server. Very unlikely unless
// the remote is compromised AND has knowledge of the local user id.
// Another possibility would be an actual user ID collision between two servers, where the likelihood is
// infinitesimally small
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Channel invite failed - channel created/fetched with wrong id",
mlog.String("remote", rc.DisplayName),
mlog.String("channel_id", invite.ChannelId),
mlog.String("channel_type", invite.Type),
mlog.String("channel_name", invite.Name),
mlog.String("team_id", invite.TeamId),
mlog.Array("dm_partics", invite.DirectParticipantIDs),
)
return fmt.Errorf("cannot create shared channel (DM channel_id=%s): %w", invite.ChannelId, model.ErrChannelAlreadyExists)
}
// mark the newly created channel read-only if requested in the invite
if invite.ReadOnly {
if err := scs.makeChannelReadOnly(channel); err != nil {
@ -199,7 +219,9 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
if _, err := scs.server.GetStore().SharedChannel().Save(sharedChannel); err != nil {
// delete the newly created channel since we could not create a SharedChannel record for it
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
if created {
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
}
return fmt.Errorf("cannot create shared channel (channel_id=%s): %w", invite.ChannelId, err)
}
@ -217,14 +239,19 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
if _, err := scs.server.GetStore().SharedChannel().SaveRemote(sharedChannelRemote); err != nil {
// delete the newly created channel since we could not create a SharedChannelRemote record for it,
// and delete the newly created SharedChannel record as well.
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
if created {
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
}
scs.server.GetStore().SharedChannel().Delete(sharedChannel.ChannelId)
return fmt.Errorf("cannot create shared channel remote (channel_id=%s): %w", invite.ChannelId, err)
}
return nil
}
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
// handleChannelCreation creates a new channel to represent the remote channel in the invitation.
// For DMs there is a chance the channel already exists (shared, unshared, shared again) and the boolean
// determines if the channel was newly created (true=new)
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, bool, error) {
if invite.Type == model.ChannelTypeDirect {
return scs.createDirectChannel(invite, rc)
}
@ -244,25 +271,27 @@ func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.Rem
// check user perms?
channel, appErr := scs.app.CreateChannelWithUser(request.EmptyContext(scs.server.Log()), channelNew, rc.CreatorId)
if appErr != nil {
return nil, fmt.Errorf("cannot create channel `%s`: %w", invite.ChannelId, appErr)
return nil, false, fmt.Errorf("cannot create channel `%s`: %w", invite.ChannelId, appErr)
}
return channel, nil
return channel, true, nil
}
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
// createDirectChannel creates a DM channel, or fetches an existing channel, and returns the channel plus a boolean
// indicating if the channel is new.
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, bool, error) {
if len(invite.DirectParticipantIDs) != 2 {
return nil, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
return nil, false, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
}
user1, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[0])
if err != nil {
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
return nil, false, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
}
user2, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[1])
if err != nil {
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
return nil, false, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
}
// determine the remote user
@ -277,15 +306,15 @@ func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.Remot
}
if !userRemote.IsRemote() {
return nil, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
return nil, false, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
}
if userLocal.IsRemote() {
return nil, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
return nil, false, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
}
if userRemote.GetRemoteID() != rc.RemoteId {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
}
// ensure remote user is allowed to DM the local user
@ -297,16 +326,30 @@ func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.Remot
mlog.String("channel_id", invite.ChannelId),
mlog.Err(appErr),
)
return nil, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
return nil, false, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
}
if !canSee {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
}
// check if this DM already exists.
channelName := model.GetDMNameFromIds(userRemote.Id, userLocal.Id)
channelExists, err := scs.server.GetStore().Channel().GetByName("", channelName, true)
if err != nil && !isNotFoundError(err) {
return nil, false, fmt.Errorf("cannot check DM channel exists (%s): %w", channelName, err)
}
if channelExists != nil {
if channelExists.Id == invite.ChannelId {
return channelExists, false, nil
}
return nil, false, fmt.Errorf("cannot create direct channel `%s`: channel exists with wrong id", channelName)
}
// create the channel
channel, appErr := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id, model.WithID(invite.ChannelId))
if appErr != nil {
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
}
return channel, nil
return channel, true, nil
}

View File

@ -141,7 +141,9 @@ func TestOnReceiveChannelInvite(t *testing.T) {
Payload: payload,
}
mockChannelStore := mocks.ChannelStore{}
channel := &model.Channel{}
channel := &model.Channel{
Id: invitation.ChannelId,
}
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, &store.ErrNotFound{})
mockStore.On("Channel").Return(&mockChannelStore)
@ -204,7 +206,9 @@ func TestOnReceiveChannelInvite(t *testing.T) {
}
mockChannelStore := mocks.ChannelStore{}
mockSharedChannelStore := mocks.SharedChannelStore{}
channel := &model.Channel{}
channel := &model.Channel{
Id: invitation.ChannelId,
}
mockUserStore := mocks.UserStore{}
mockUserStore.On("Get", mockTypeContext, tc.user1.Id).
@ -213,6 +217,8 @@ func TestOnReceiveChannelInvite(t *testing.T) {
Return(tc.user2, nil)
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
mockChannelStore.On("GetByName", "", mockTypeString, true).Return(nil, &store.ErrNotFound{})
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
mockStore.On("Channel").Return(&mockChannelStore)