MM-57786 Fix Shared Channels plugin API (#26753)

* always ping on plugin registration; SharedChannel.IsValid now allows no team for GMs

* wait for services to start before pinging

* ping plugin remotes synchronously on startup

* remove the waitForInterClusterServices logic

* don't set RemoteId when inviting a remote to a channel

* Update server/public/model/remote_cluster_test.go

Co-authored-by: Ibrahim Serdar Acikgoz <serdaracikgoz86@gmail.com>

* address review comments

---------

Co-authored-by: Mattermost Build <build@mattermost.com>
Co-authored-by: Ibrahim Serdar Acikgoz <serdaracikgoz86@gmail.com>
Doug Lauder 2024-04-15 16:18:25 -04:00 committed by GitHub
parent 9e8f9a3715
commit 6aaabfb376
13 changed files with 152 additions and 32 deletions

@@ -62,6 +62,15 @@ func (a *App) RegisterPluginForSharedChannels(opts model.RegisterPluginOpts) (re
 		mlog.String("remote_id", rcSaved.RemoteId),
 	)
 
+	// ping the plugin remote immediately if the service is running.
+	// If the service is not available the ping will happen once the
+	// service starts. This is expected since plugins start before the
+	// service.
+	rcService, _ := a.GetRemoteClusterService()
+	if rcService != nil {
+		rcService.PingNow(rcSaved)
+	}
+
 	return rcSaved.RemoteId, nil
 }
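
For context, a plugin is expected to register itself during activation, and plugin activation happens before the remote cluster service starts. A minimal plugin-side sketch, assuming the plugin API exposes RegisterPluginForSharedChannels as the app-layer method above suggests (the RegisterPluginOpts field names are illustrative assumptions, not taken from this diff):

	// OnActivate runs before the remote cluster service is up on server
	// startup, so the ping above is skipped then; the synchronous startup
	// ping added in resume() later in this commit covers that case.
	func (p *Plugin) OnActivate() error {
		opts := model.RegisterPluginOpts{ // field names assumed for illustration
			Displayname: "Example Sync Plugin",
			PluginID:    "com.example.sync",
		}
		remoteID, err := p.API.RegisterPluginForSharedChannels(opts)
		if err != nil {
			return err
		}
		p.remoteID = remoteID // stable across restarts; see newIDFromBytes below
		return nil
	}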

@@ -191,6 +191,10 @@ func (s sqlRemoteClusterStore) GetAll(filter model.RemoteClusterQueryFilter) ([]
 		query = query.Where(sq.Eq{"rc.PluginID": filter.PluginID})
 	}
 
+	if filter.OnlyPlugins {
+		query = query.Where(sq.NotEq{"rc.PluginID": ""})
+	}
+
 	if filter.RequireOptions != 0 {
 		query = query.Where(sq.NotEq{fmt.Sprintf("(rc.Options & %d)", filter.RequireOptions): 0})
 	}
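
Both new filters compose into plain SQL through squirrel. A runnable sketch of the clauses they generate, using the table alias and columns from the query above (the flag value 4 is an arbitrary example, not a real bitflag constant):

	package main

	import (
		"fmt"

		sq "github.com/Masterminds/squirrel"
	)

	func main() {
		query := sq.Select("rc.*").From("RemoteClusters rc")
		// OnlyPlugins: any remote with a non-empty PluginID
		query = query.Where(sq.NotEq{"rc.PluginID": ""})
		// RequireOptions: keep remotes whose Options has the given bits set
		query = query.Where(sq.NotEq{fmt.Sprintf("(rc.Options & %d)", 4): 0})

		sql, args, _ := query.ToSql()
		fmt.Println(sql, args)
		// SELECT rc.* FROM RemoteClusters rc WHERE rc.PluginID <> ? AND (rc.Options & 4) <> ?
	}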

@@ -31,7 +31,7 @@ func makeSiteURL() string {
 	return "www.example.com/" + model.NewId()
 }
 
-func testRemoteClusterSave(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterSave(t *testing.T, _ request.CTX, ss store.Store) {
 	t.Run("Save", func(t *testing.T) {
 		rc := &model.RemoteCluster{
 			Name: "some_remote",
@@ -145,7 +145,7 @@ func testRemoteClusterSave(t *testing.T, rctx request.CTX, ss store.Store) {
 	})
 }
 
-func testRemoteClusterDelete(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterDelete(t *testing.T, _ request.CTX, ss store.Store) {
 	t.Run("Delete", func(t *testing.T) {
 		rc := &model.RemoteCluster{
 			Name: "shortlived_remote",
@@ -167,7 +167,7 @@ func testRemoteClusterDelete(t *testing.T, rctx request.CTX, ss store.Store) {
 	})
 }
 
-func testRemoteClusterGet(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterGet(t *testing.T, _ request.CTX, ss store.Store) {
 	t.Run("Get", func(t *testing.T) {
 		rc := &model.RemoteCluster{
 			Name: "shortlived_remote_2",
@@ -192,7 +192,7 @@ func testRemoteClusterGet(t *testing.T, rctx request.CTX, ss store.Store) {
 	})
 }
 
-func testRemoteClusterGetByPluginID(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterGetByPluginID(t *testing.T, _ request.CTX, ss store.Store) {
 	const pluginID = "com.acme.bogus.plugin"
 
 	t.Run("GetByPluginID", func(t *testing.T) {
@@ -217,7 +217,7 @@ func testRemoteClusterGetByPluginID(t *testing.T, rctx request.CTX, ss store.Sto
 	})
 }
 
-func testRemoteClusterGetAll(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterGetAll(t *testing.T, _ request.CTX, ss store.Store) {
 	require.NoError(t, clearRemoteClusters(ss))
 
 	userId := model.NewId()
@@ -230,11 +230,15 @@ func testRemoteClusterGetAll(t *testing.T, rctx request.CTX, ss store.Store) {
 		{Name: "another_online_remote", CreatorId: model.NewId(), SiteURL: makeSiteURL(), LastPingAt: now, Topics: ""},
 		{Name: "another_offline_remote", CreatorId: model.NewId(), SiteURL: makeSiteURL(), LastPingAt: pingLongAgo, Topics: " shared "},
 		{Name: "brand_new_offline_remote", CreatorId: userId, SiteURL: "", LastPingAt: 0, Topics: " bogus shared stuff "},
+		{Name: "offline_plugin_remote", CreatorId: model.NewId(), SiteURL: makeSiteURL(), PluginID: model.NewId(), LastPingAt: 0, Topics: " pluginshare "},
+		{Name: "online_plugin_remote", CreatorId: model.NewId(), SiteURL: makeSiteURL(), PluginID: model.NewId(), LastPingAt: now, Topics: " pluginshare "},
 	}
 
 	idsAll := make([]string, 0)
 	idsOnline := make([]string, 0)
 	idsShareTopic := make([]string, 0)
+	idsPlugin := make([]string, 0)
+	idsConfirmed := make([]string, 0)
 
 	for _, item := range data {
 		online := item.LastPingAt == now
@@ -247,6 +251,12 @@ func testRemoteClusterGetAll(t *testing.T, rctx request.CTX, ss store.Store) {
 		if strings.Contains(saved.Topics, " shared ") {
 			idsShareTopic = append(idsShareTopic, saved.RemoteId)
 		}
+		if item.PluginID != "" {
+			idsPlugin = append(idsPlugin, saved.RemoteId)
+		}
+		if item.SiteURL != "" {
+			idsConfirmed = append(idsConfirmed, saved.RemoteId)
+		}
 	}
 
 	t.Run("GetAll", func(t *testing.T) {
@@ -315,10 +325,28 @@ func testRemoteClusterGetAll(t *testing.T, rctx request.CTX, ss store.Store) {
 		remotes, err := ss.RemoteCluster().GetAll(filter)
 		require.NoError(t, err)
 		// make sure only confirmed returned
-		assert.Len(t, remotes, 4)
 		for _, rc := range remotes {
 			assert.NotEmpty(t, rc.SiteURL)
 		}
+		// make sure all confirmed returned
+		ids := getIds(remotes)
+		assert.ElementsMatch(t, ids, idsConfirmed)
 	})
+
+	t.Run("GetAll only plugins", func(t *testing.T) {
+		filter := model.RemoteClusterQueryFilter{
+			OnlyPlugins: true,
+		}
+		remotes, err := ss.RemoteCluster().GetAll(filter)
+		require.NoError(t, err)
+		// make sure only plugin remotes returned
+		for _, rc := range remotes {
+			assert.NotEmpty(t, rc.PluginID)
+			assert.True(t, rc.IsPlugin())
+		}
+		// make sure all the plugin remotes were returned.
+		ids := getIds(remotes)
+		assert.ElementsMatch(t, ids, idsPlugin)
+	})
 }
@@ -542,7 +570,7 @@ func getIds(remotes []*model.RemoteCluster) []string {
 	return ids
 }
 
-func testRemoteClusterGetByTopic(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterGetByTopic(t *testing.T, _ request.CTX, ss store.Store) {
 	require.NoError(t, clearRemoteClusters(ss))
 
 	rcData := []*model.RemoteCluster{
@@ -587,7 +615,7 @@ func testRemoteClusterGetByTopic(t *testing.T, rctx request.CTX, ss store.Store)
 	}
 }
 
-func testRemoteClusterUpdateTopics(t *testing.T, rctx request.CTX, ss store.Store) {
+func testRemoteClusterUpdateTopics(t *testing.T, _ request.CTX, ss store.Store) {
 	remoteId := model.NewId()
 	rc := &model.RemoteCluster{
 		DisplayName: "Blap Inc",

@@ -73,7 +73,7 @@ type mockApp struct {
 	pingCounts map[string]int
 }
 
-func newMockApp(t *testing.T, offlinePluginIDs []string) *mockApp {
+func newMockApp(_ *testing.T, offlinePluginIDs []string) *mockApp {
 	return &mockApp{
 		offlinePluginIDs: offlinePluginIDs,
 		pingCounts:       make(map[string]int),

@@ -34,6 +34,23 @@ func (rcs *Service) PingNow(rc *model.RemoteCluster) {
 	}
 }
 
+// pingAllNow emits a ping to all remotes immediately without waiting for next ping loop.
+func (rcs *Service) pingAllNow(filter model.RemoteClusterQueryFilter) {
+	// get all remotes, including any previously offline.
+	remotes, err := rcs.server.GetStore().RemoteCluster().GetAll(filter)
+	if err != nil {
+		rcs.server.Log().Log(mlog.LvlRemoteClusterServiceError, "Ping all remote clusters failed (could not get list of remotes)", mlog.Err(err))
+		return
+	}
+
+	for _, rc := range remotes {
+		// filter out unconfirmed invites so we don't ping them without permission
+		if rc.IsConfirmed() {
+			rcs.PingNow(rc)
+		}
+	}
+}
+
 // pingLoop periodically sends a ping to all remote clusters.
 func (rcs *Service) pingLoop(done <-chan struct{}) {
 	pingChan := make(chan *model.RemoteCluster, MaxConcurrentSends*2)
@@ -53,24 +70,8 @@ func (rcs *Service) pingGenerator(pingChan chan *model.RemoteCluster, done <-cha
 		pingFreq := rcs.GetPingFreq()
 		start := time.Now()
 
-		// get all remotes, including any previously offline.
-		remotes, err := rcs.server.GetStore().RemoteCluster().GetAll(model.RemoteClusterQueryFilter{})
-		if err != nil {
-			rcs.server.Log().Log(mlog.LvlRemoteClusterServiceError, "Ping remote cluster failed (could not get list of remotes)", mlog.Err(err))
-			select {
-			case <-time.After(pingFreq):
-				continue
-			case <-done:
-				return
-			}
-		}
-
-		for _, rc := range remotes {
-			// filter out unconfirmed invites so we don't ping them without permission
-			if rc.IsConfirmed() {
-				pingChan <- rc
-			}
-		}
+		// ping all remotes, including any previously offline.
+		rcs.pingAllNow(model.RemoteClusterQueryFilter{})
 
 		// try to maintain frequency
 		elapsed := time.Since(start)

@@ -261,6 +261,10 @@ func (rcs *Service) resume() {
 	rcs.done = make(chan struct{})
 
 	if !disablePing {
+		// first ping all the plugin remotes immediately, synchronously.
+		rcs.pingAllNow(model.RemoteClusterQueryFilter{OnlyPlugins: true})
+
+		// start the async ping loop
 		rcs.pingLoop(rcs.done)
 	}
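
This ordering is the crux of the startup fix: plugin remotes are pinged synchronously before the periodic loop starts, so a plugin that registered while the service was down is marked online as soon as the service resumes, instead of waiting up to one ping interval. The condensed timeline, using only names from this diff:

	// 1. plugin OnActivate -> RegisterPluginForSharedChannels; the remote
	//    cluster service is not running yet, so no ping is sent.
	// 2. the service starts and resume() runs:
	rcs.pingAllNow(model.RemoteClusterQueryFilter{OnlyPlugins: true}) // synchronous
	rcs.pingLoop(rcs.done)                                            // periodic, async
	// 3. the plugin remote's LastPingAt is set and it is treated as online.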

@@ -72,6 +72,37 @@ func (scs *Service) SendChannelInvite(channel *model.Channel, userId string, rc
 
 	msg := model.NewRemoteClusterMsg(TopicChannelInvite, json)
 
+	// onInvite is called after invite is sent, whether to a remote cluster or plugin.
+	onInvite := func(_ model.RemoteClusterMsg, rc *model.RemoteCluster, resp *remotecluster.Response, err error) {
+		if err != nil || !resp.IsSuccess() {
+			scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error sending channel invite for %s: %s", rc.DisplayName, combineErrors(err, resp.Err)))
+			return
+		}
+
+		scr := &model.SharedChannelRemote{
+			ChannelId:         sc.ChannelId,
+			CreatorId:         userId,
+			RemoteId:          rc.RemoteId,
+			IsInviteAccepted:  true,
+			IsInviteConfirmed: true,
+			LastPostCreateAt:  model.GetMillis(),
+			LastPostUpdateAt:  model.GetMillis(),
+		}
+		if _, err = scs.server.GetStore().SharedChannel().SaveRemote(scr); err != nil {
+			scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("Error confirming channel invite for %s: %v", rc.DisplayName, err))
+			return
+		}
+
+		scs.NotifyChannelChanged(sc.ChannelId)
+		scs.sendEphemeralPost(channel.Id, userId, fmt.Sprintf("`%s` has been added to channel.", rc.DisplayName))
+	}
+
+	if rc.IsPlugin() {
+		// for now plugins are considered fully invited automatically
+		// TODO: MM-57537 create plugin hook that passes invitation to plugins if BitflagOptionAutoInvited is not set
+		onInvite(msg, rc, &remotecluster.Response{Status: remotecluster.ResponseStatusOK}, nil)
+		return nil
+	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), remotecluster.SendTimeout)
 	defer cancel()

@@ -121,7 +121,7 @@ func (scs *Service) InviteRemoteToChannel(channelID, remoteID, userID string, sh
 		ChannelId: channelID,
 		CreatorId: userID,
 		Home:      true,
-		RemoteId:  remoteID,
+		RemoteId:  "", // channel originates locally
 	}
 
 	if _, err = scs.ShareChannel(sc); err != nil {
 		return model.NewAppError("InviteRemoteToChannel", "api.command_share.share_channel.error",
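
The one-line fix encodes a distinction worth spelling out: SharedChannel.RemoteId records where a channel originates (empty when the channel is homed locally), while each invited remote gets its own SharedChannelRemote row, created in the onInvite callback above. Side by side, with values as used in these hunks:

	sc := &model.SharedChannel{
		ChannelId: channelID,
		CreatorId: userID,
		Home:      true,
		RemoteId:  "", // origin: this server, not the invitee
	}
	scr := &model.SharedChannelRemote{
		ChannelId: channelID,
		CreatorId: userID,
		RemoteId:  rc.RemoteId, // the invited remote
	}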

@@ -219,7 +219,7 @@ func (scs *Service) upsertSyncUser(c request.CTX, user *model.User, channel *mod
 	return userSaved, nil
 }
 
-func (scs *Service) insertSyncUser(rctx request.CTX, user *model.User, channel *model.Channel, rc *model.RemoteCluster) (*model.User, error) {
+func (scs *Service) insertSyncUser(rctx request.CTX, user *model.User, _ *model.Channel, rc *model.RemoteCluster) (*model.User, error) {
 	var err error
 	var userSaved *model.User
 	var suffix string
@@ -270,7 +270,7 @@ func (scs *Service) insertSyncUser(rctx request.CTX, user *model.User, channel *
 	return nil, fmt.Errorf("error inserting sync user %s: %w", user.Id, err)
 }
 
-func (scs *Service) updateSyncUser(rctx request.CTX, patch *model.UserPatch, user *model.User, channel *model.Channel, rc *model.RemoteCluster) (*model.User, error) {
+func (scs *Service) updateSyncUser(rctx request.CTX, patch *model.UserPatch, user *model.User, _ *model.Channel, rc *model.RemoteCluster) (*model.User, error) {
 	var err error
 	var update *model.UserUpdate
 	var suffix string

@@ -107,6 +107,10 @@ func (scs *Service) syncForRemote(task syncTask, rc *model.RemoteCluster) error
 		if scr, err = scs.server.GetStore().SharedChannel().SaveRemote(scr); err != nil {
 			return fmt.Errorf("cannot auto-create shared channel remote (channel_id=%s, remote_id=%s): %w", task.channelID, rc.RemoteId, err)
 		}
+		scs.server.Log().Log(mlog.LvlSharedChannelServiceDebug, "Auto-invited remote to channel (BitflagOptionAutoInvited)",
+			mlog.String("remote", rc.DisplayName),
+			mlog.String("channel_id", task.channelID),
+		)
 	} else if err != nil {
 		return err
 	}
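
For reference, the auto-invite branch this log line sits in is gated by an option bit on the remote. A sketch of flagging and then selecting such remotes, assuming Bitmask is a plain integer bitmask as IsOptionFlagSet in the model file below suggests:

	// opt a remote in to auto-invite (OR the flag into Options)
	rc.Options |= model.BitflagOptionAutoInvited

	// later, fetch only remotes that asked to be auto-invited
	remotes, err := ss.RemoteCluster().GetAll(model.RemoteClusterQueryFilter{
		RequireOptions: model.BitflagOptionAutoInvited,
	})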

@@ -6,7 +6,9 @@ package model
 import (
 	"crypto/aes"
 	"crypto/cipher"
+	"crypto/md5"
 	"crypto/rand"
+	"encoding/base32"
 	"encoding/json"
 	"errors"
 	"io"
@@ -80,8 +82,12 @@ func (rc *RemoteCluster) Auditable() map[string]interface{} {
 func (rc *RemoteCluster) PreSave() {
 	if rc.RemoteId == "" {
+		if rc.PluginID != "" {
+			rc.RemoteId = newIDFromBytes([]byte(rc.PluginID))
+		} else {
 			rc.RemoteId = NewId()
+		}
 	}
 
 	if rc.DisplayName == "" {
 		rc.DisplayName = rc.Name
@@ -120,6 +126,16 @@ func (rc *RemoteCluster) IsValid() *AppError {
 	return nil
 }
 
+func newIDFromBytes(b []byte) string {
+	hash := md5.New()
+	_, _ = hash.Write(b)
+	buf := hash.Sum(nil)
+
+	var encoding = base32.NewEncoding("ybndrfg8ejkmcpqxot1uwisza345h769").WithPadding(base32.NoPadding)
+	id := encoding.EncodeToString(buf)
+	return id[:26]
+}
+
 func (rc *RemoteCluster) IsOptionFlagSet(flag Bitmask) bool {
 	return rc.Options.IsBitSet(flag)
 }
@@ -385,5 +401,6 @@ type RemoteClusterQueryFilter struct {
 	CreatorId      string
 	OnlyConfirmed  bool
 	PluginID       string
+	OnlyPlugins    bool
 	RequireOptions Bitmask
 }
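
Why 26 characters: md5 yields 16 bytes (128 bits) and unpadded base32 emits one character per 5 bits, so the encoding is ceil(128/5) = 26 characters, exactly the length of a standard Mattermost ID; the id[:26] slice is only a guard. The practical effect, mirrored by the test in the next file:

	id1 := newIDFromBytes([]byte("com.mattermost.msteams-sync"))
	id2 := newIDFromBytes([]byte("com.mattermost.msteams-sync"))
	// id1 == id2 and IsValidId(id1) holds: re-registering the same plugin
	// always maps to the same RemoteId, making registration idempotent.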

@@ -136,3 +136,25 @@ func makeInvite(url string) RemoteClusterInvite {
 		Token: NewId(),
 	}
 }
+
+func TestNewIDFromBytes(t *testing.T) {
+	tests := []struct {
+		name string
+		ss   string
+	}{
+		{name: "empty", ss: ""},
+		{name: "very short", ss: "x"},
+		{name: "normal", ss: "com.mattermost.msteams-sync"},
+		{name: "long", ss: "com.mattermost.msteams-synccom.mattermost.msteams-synccom.mattermost.msteams-synccom.mattermost.msteams-sync"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got1 := newIDFromBytes([]byte(tt.ss))
+			assert.True(t, IsValidId(got1), "not a valid id")
+
+			got2 := newIDFromBytes([]byte(tt.ss))
+			assert.Equal(t, got1, got2, "newIDFromBytes must generate same id for same inputs")
+		})
+	}
+}

@@ -42,7 +42,7 @@ func (sc *SharedChannel) IsValid() *AppError {
 		return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.id.app_error", nil, "ChannelId="+sc.ChannelId, http.StatusBadRequest)
 	}
 
-	if sc.Type != ChannelTypeDirect && !IsValidId(sc.TeamId) {
+	if sc.Type != ChannelTypeDirect && sc.Type != ChannelTypeGroup && !IsValidId(sc.TeamId) {
 		return NewAppError("SharedChannel.IsValid", "model.channel.is_valid.id.app_error", nil, "TeamId="+sc.TeamId, http.StatusBadRequest)
 	}
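
With this relaxation, group messages, which like DMs have no team, can pass validation and be shared. A sketch of a GM shared-channel record that clears the check above (the remaining field values are plausible guesses at what the rest of IsValid requires, not verified against this diff):

	sc := &model.SharedChannel{
		ChannelId: model.NewId(),
		TeamId:    "", // group messages have no team
		Type:      model.ChannelTypeGroup,
		ShareName: "group_message_share",
		CreatorId: model.NewId(),
		RemoteId:  model.NewId(),
		CreateAt:  model.GetMillis(),
		UpdateAt:  model.GetMillis(),
	}
	appErr := sc.IsValid() // no longer rejects the empty TeamId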