mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-57867 Don't delete existing DM on invitation error (#27357)
* ensure channel invitations create new channels; don't delete pre-existing channels on failure cleanup * update comment
This commit is contained in:
parent
f41e8ad756
commit
8181a9ddff
@ -164,6 +164,7 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
|
||||
|
||||
// check if channel already exists
|
||||
var channel *model.Channel
|
||||
var created bool
|
||||
_, err := scs.server.GetStore().Channel().Get(invite.ChannelId, true)
|
||||
if err == nil {
|
||||
// the channel already exists on this server; could be the remote is trying to re-share it (not allowed at this time).
|
||||
@ -172,10 +173,29 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
|
||||
}
|
||||
|
||||
// create new local channel to sync with the remote channel
|
||||
if channel, err = scs.handleChannelCreation(invite, rc); err != nil {
|
||||
if channel, created, err = scs.handleChannelCreation(invite, rc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sanity check to ensure the channel returned has the expected id. Otherwise sync will not work as expected and will fail
|
||||
// silently.
|
||||
if invite.ChannelId != channel.Id {
|
||||
// as of this writing, this scenario should only be possible if the invite included a DM channel invitation with a
|
||||
// combination of two user ids (one remote, one local) that already have a DM on this server. Very unlikely unless
|
||||
// the remote is compromised AND has knowledge of the local user id.
|
||||
// Another possibility would be an actual user ID collision between two servers, where the likelihood is
|
||||
// infinitesimally small
|
||||
scs.server.Log().Log(mlog.LvlSharedChannelServiceError, "Channel invite failed - channel created/fetched with wrong id",
|
||||
mlog.String("remote", rc.DisplayName),
|
||||
mlog.String("channel_id", invite.ChannelId),
|
||||
mlog.String("channel_type", invite.Type),
|
||||
mlog.String("channel_name", invite.Name),
|
||||
mlog.String("team_id", invite.TeamId),
|
||||
mlog.Array("dm_partics", invite.DirectParticipantIDs),
|
||||
)
|
||||
return fmt.Errorf("cannot create shared channel (DM channel_id=%s): %w", invite.ChannelId, model.ErrChannelAlreadyExists)
|
||||
}
|
||||
|
||||
// mark the newly created channel read-only if requested in the invite
|
||||
if invite.ReadOnly {
|
||||
if err := scs.makeChannelReadOnly(channel); err != nil {
|
||||
@ -199,7 +219,9 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
|
||||
|
||||
if _, err := scs.server.GetStore().SharedChannel().Save(sharedChannel); err != nil {
|
||||
// delete the newly created channel since we could not create a SharedChannel record for it
|
||||
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
|
||||
if created {
|
||||
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
|
||||
}
|
||||
return fmt.Errorf("cannot create shared channel (channel_id=%s): %w", invite.ChannelId, err)
|
||||
}
|
||||
|
||||
@ -217,14 +239,19 @@ func (scs *Service) onReceiveChannelInvite(msg model.RemoteClusterMsg, rc *model
|
||||
if _, err := scs.server.GetStore().SharedChannel().SaveRemote(sharedChannelRemote); err != nil {
|
||||
// delete the newly created channel since we could not create a SharedChannelRemote record for it,
|
||||
// and delete the newly created SharedChannel record as well.
|
||||
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
|
||||
if created {
|
||||
scs.app.PermanentDeleteChannel(request.EmptyContext(scs.server.Log()), channel)
|
||||
}
|
||||
scs.server.GetStore().SharedChannel().Delete(sharedChannel.ChannelId)
|
||||
return fmt.Errorf("cannot create shared channel remote (channel_id=%s): %w", invite.ChannelId, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
|
||||
// handleChannelCreation creates a new channel to represent the remote channel in the invitation.
|
||||
// For DMs there is a chance the channel already exists (shared, unshared, shared again) and the boolean
|
||||
// determines if the channel was newly created (true=new)
|
||||
func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, bool, error) {
|
||||
if invite.Type == model.ChannelTypeDirect {
|
||||
return scs.createDirectChannel(invite, rc)
|
||||
}
|
||||
@ -244,25 +271,27 @@ func (scs *Service) handleChannelCreation(invite channelInviteMsg, rc *model.Rem
|
||||
// check user perms?
|
||||
channel, appErr := scs.app.CreateChannelWithUser(request.EmptyContext(scs.server.Log()), channelNew, rc.CreatorId)
|
||||
if appErr != nil {
|
||||
return nil, fmt.Errorf("cannot create channel `%s`: %w", invite.ChannelId, appErr)
|
||||
return nil, false, fmt.Errorf("cannot create channel `%s`: %w", invite.ChannelId, appErr)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
return channel, true, nil
|
||||
}
|
||||
|
||||
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, error) {
|
||||
// createDirectChannel creates a DM channel, or fetches an existing channel, and returns the channel plus a boolean
|
||||
// indicating if the channel is new.
|
||||
func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.RemoteCluster) (*model.Channel, bool, error) {
|
||||
if len(invite.DirectParticipantIDs) != 2 {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s` insufficient participant count `%d`", invite.ChannelId, len(invite.DirectParticipantIDs))
|
||||
}
|
||||
|
||||
user1, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s` cannot fetch user1 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[0], err)
|
||||
}
|
||||
|
||||
user2, err := scs.server.GetStore().User().Get(context.TODO(), invite.DirectParticipantIDs[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s` cannot fetch user2 (%s): %w", invite.ChannelId, invite.DirectParticipantIDs[1], err)
|
||||
}
|
||||
|
||||
// determine the remote user
|
||||
@ -277,15 +306,15 @@ func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.Remot
|
||||
}
|
||||
|
||||
if !userRemote.IsRemote() {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s` remote user is not remote (%s)", invite.ChannelId, userRemote.Id)
|
||||
}
|
||||
|
||||
if userLocal.IsRemote() {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s` local user is not local (%s)", invite.ChannelId, userLocal.Id)
|
||||
}
|
||||
|
||||
if userRemote.GetRemoteID() != rc.RemoteId {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrRemoteIDMismatch)
|
||||
}
|
||||
|
||||
// ensure remote user is allowed to DM the local user
|
||||
@ -297,16 +326,30 @@ func (scs *Service) createDirectChannel(invite channelInviteMsg, rc *model.Remot
|
||||
mlog.String("channel_id", invite.ChannelId),
|
||||
mlog.Err(appErr),
|
||||
)
|
||||
return nil, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
|
||||
return nil, false, fmt.Errorf("cannot check user visibility for DM (%s) creation: %w", invite.ChannelId, appErr)
|
||||
}
|
||||
if !canSee {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, ErrUserDMPermission)
|
||||
}
|
||||
|
||||
// check if this DM already exists.
|
||||
channelName := model.GetDMNameFromIds(userRemote.Id, userLocal.Id)
|
||||
channelExists, err := scs.server.GetStore().Channel().GetByName("", channelName, true)
|
||||
if err != nil && !isNotFoundError(err) {
|
||||
return nil, false, fmt.Errorf("cannot check DM channel exists (%s): %w", channelName, err)
|
||||
}
|
||||
if channelExists != nil {
|
||||
if channelExists.Id == invite.ChannelId {
|
||||
return channelExists, false, nil
|
||||
}
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s`: channel exists with wrong id", channelName)
|
||||
}
|
||||
|
||||
// create the channel
|
||||
channel, appErr := scs.app.GetOrCreateDirectChannel(request.EmptyContext(scs.server.Log()), userRemote.Id, userLocal.Id, model.WithID(invite.ChannelId))
|
||||
if appErr != nil {
|
||||
return nil, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
|
||||
return nil, false, fmt.Errorf("cannot create direct channel `%s`: %w", invite.ChannelId, appErr)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
return channel, true, nil
|
||||
}
|
||||
|
@ -141,7 +141,9 @@ func TestOnReceiveChannelInvite(t *testing.T) {
|
||||
Payload: payload,
|
||||
}
|
||||
mockChannelStore := mocks.ChannelStore{}
|
||||
channel := &model.Channel{}
|
||||
channel := &model.Channel{
|
||||
Id: invitation.ChannelId,
|
||||
}
|
||||
|
||||
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, &store.ErrNotFound{})
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
@ -204,7 +206,9 @@ func TestOnReceiveChannelInvite(t *testing.T) {
|
||||
}
|
||||
mockChannelStore := mocks.ChannelStore{}
|
||||
mockSharedChannelStore := mocks.SharedChannelStore{}
|
||||
channel := &model.Channel{}
|
||||
channel := &model.Channel{
|
||||
Id: invitation.ChannelId,
|
||||
}
|
||||
|
||||
mockUserStore := mocks.UserStore{}
|
||||
mockUserStore.On("Get", mockTypeContext, tc.user1.Id).
|
||||
@ -213,6 +217,8 @@ func TestOnReceiveChannelInvite(t *testing.T) {
|
||||
Return(tc.user2, nil)
|
||||
|
||||
mockChannelStore.On("Get", invitation.ChannelId, true).Return(nil, errors.New("boom"))
|
||||
mockChannelStore.On("GetByName", "", mockTypeString, true).Return(nil, &store.ErrNotFound{})
|
||||
|
||||
mockSharedChannelStore.On("Save", mock.Anything).Return(nil, nil)
|
||||
mockSharedChannelStore.On("SaveRemote", mock.Anything).Return(nil, nil)
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
|
Loading…
Reference in New Issue
Block a user