diff --git a/server/channels/api4/channel_bookmark_test.go b/server/channels/api4/channel_bookmark_test.go index 1769d55020..7505240329 100644 --- a/server/channels/api4/channel_bookmark_test.go +++ b/server/channels/api4/channel_bookmark_test.go @@ -16,6 +16,7 @@ import ( ) func TestCreateChannelBookmark(t *testing.T) { + t.Skip("MM-57312") os.Setenv("MM_FEATUREFLAGS_ChannelBookmarks", "true") defer os.Unsetenv("MM_FEATUREFLAGS_ChannelBookmarks") diff --git a/server/channels/app/plugin_db_driver.go b/server/channels/app/plugin_db_driver.go index 56ed0ceaa2..4ab4e0c00e 100644 --- a/server/channels/app/plugin_db_driver.go +++ b/server/channels/app/plugin_db_driver.go @@ -12,6 +12,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/shared/mlog" ) // DriverImpl implements the plugin.Driver interface on the server-side. @@ -21,7 +22,7 @@ import ( type DriverImpl struct { s *Server connMut sync.RWMutex - connMap map[string]*sql.Conn + connMap map[string]*connMeta txMut sync.Mutex txMap map[string]driver.Tx stMut sync.RWMutex @@ -30,10 +31,15 @@ type DriverImpl struct { rowsMap map[string]driver.Rows } +type connMeta struct { + pluginID string + conn *sql.Conn +} + func NewDriverImpl(s *Server) *DriverImpl { return &DriverImpl{ s: s, - connMap: make(map[string]*sql.Conn), + connMap: make(map[string]*connMeta), txMap: make(map[string]driver.Tx), stMap: make(map[string]driver.Stmt), rowsMap: make(map[string]driver.Rows), @@ -41,6 +47,14 @@ func NewDriverImpl(s *Server) *DriverImpl { } func (d *DriverImpl) Conn(isMaster bool) (string, error) { + return d.conn(isMaster, "") +} + +func (d *DriverImpl) ConnWithPluginID(isMaster bool, pluginID string) (string, error) { + return d.conn(isMaster, pluginID) +} + +func (d *DriverImpl) conn(isMaster bool, pluginID string) (string, error) { dbFunc := d.s.Platform().Store.GetInternalMasterDB if !isMaster { dbFunc = d.s.Platform().Store.GetInternalReplicaDB @@ -54,7 +68,7 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) { } connID := model.NewId() d.connMut.Lock() - d.connMap[connID] = conn + d.connMap[connID] = &connMeta{pluginID: pluginID, conn: conn} d.connMut.Unlock() return connID, nil } @@ -70,13 +84,13 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) { func (d *DriverImpl) ConnPing(connID string) error { d.connMut.RLock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] d.connMut.RUnlock() if !ok { return driver.ErrBadConn } - return conn.Raw(func(innerConn any) error { + return entry.conn.Raw(func(innerConn any) error { return innerConn.(driver.Pinger).Ping(context.Background()) }) } @@ -84,13 +98,13 @@ func (d *DriverImpl) ConnPing(connID string) error { func (d *DriverImpl) ConnQuery(connID, q string, args []driver.NamedValue) (_ string, err error) { var rows driver.Rows d.connMut.RLock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] d.connMut.RUnlock() if !ok { return "", driver.ErrBadConn } - err = conn.Raw(func(innerConn any) error { + err = entry.conn.Raw(func(innerConn any) error { rows, err = innerConn.(driver.QueryerContext).QueryContext(context.Background(), q, args) return err }) @@ -110,13 +124,13 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu var res driver.Result var ret plugin.ResultContainer d.connMut.RLock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] d.connMut.RUnlock() if !ok { return ret, driver.ErrBadConn } - err = conn.Raw(func(innerConn any) error { + err = entry.conn.Raw(func(innerConn any) error { res, err = innerConn.(driver.ExecerContext).ExecContext(context.Background(), q, args) return err }) @@ -132,7 +146,7 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu func (d *DriverImpl) ConnClose(connID string) error { d.connMut.Lock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] if !ok { d.connMut.Unlock() return driver.ErrBadConn @@ -140,19 +154,35 @@ func (d *DriverImpl) ConnClose(connID string) error { delete(d.connMap, connID) d.connMut.Unlock() - return conn.Close() + return entry.conn.Close() +} + +// ShutdownConns will close any remaining connections for the given pluginID +// if they weren't already closed by the plugin itself. +func (d *DriverImpl) ShutdownConns(pluginID string) { + d.connMut.Lock() + defer d.connMut.Unlock() + for connID, entry := range d.connMap { + if entry.pluginID == pluginID { + err := entry.conn.Close() + if err != nil { + d.s.Log().Error("Error while closing DB connection", mlog.Err(err), mlog.String("pluginID", pluginID)) + } + delete(d.connMap, connID) + } + } } func (d *DriverImpl) Tx(connID string, opts driver.TxOptions) (_ string, err error) { var tx driver.Tx d.connMut.RLock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] d.connMut.RUnlock() if !ok { return "", driver.ErrBadConn } - err = conn.Raw(func(innerConn any) error { + err = entry.conn.Raw(func(innerConn any) error { tx, err = innerConn.(driver.ConnBeginTx).BeginTx(context.Background(), opts) return err }) @@ -188,13 +218,13 @@ func (d *DriverImpl) TxRollback(txID string) error { func (d *DriverImpl) Stmt(connID, q string) (_ string, err error) { var stmt driver.Stmt d.connMut.RLock() - conn, ok := d.connMap[connID] + entry, ok := d.connMap[connID] d.connMut.RUnlock() if !ok { return "", driver.ErrBadConn } - err = conn.Raw(func(innerConn any) error { + err = entry.conn.Raw(func(innerConn any) error { stmt, err = innerConn.(driver.Conn).Prepare(q) return err }) diff --git a/server/channels/app/plugin_db_driver_test.go b/server/channels/app/plugin_db_driver_test.go index a2c428fb6f..6f9c8f0d61 100644 --- a/server/channels/app/plugin_db_driver_test.go +++ b/server/channels/app/plugin_db_driver_test.go @@ -19,3 +19,22 @@ func TestConnCreateTimeout(t *testing.T) { _, err := d.Conn(true) require.Error(t, err) } + +func TestShutdownPluginConns(t *testing.T) { + th := Setup(t) + defer th.TearDown() + + d := NewDriverImpl(th.Server) + _, err := d.ConnWithPluginID(true, "plugin1") + require.NoError(t, err) + _, err = d.ConnWithPluginID(true, "plugin2") + require.NoError(t, err) + _, err = d.ConnWithPluginID(true, "plugin1") + require.NoError(t, err) + + require.Len(t, d.connMap, 3) + d.ShutdownConns("plugin1") + require.Len(t, d.connMap, 1) + d.ShutdownConns("plugin2") + require.Len(t, d.connMap, 0) +} diff --git a/server/public/plugin/driver.go b/server/public/plugin/driver.go index 68eebeea1c..ba5c5d9aca 100644 --- a/server/public/plugin/driver.go +++ b/server/public/plugin/driver.go @@ -62,3 +62,14 @@ type Driver interface { // ResetSession(ctx context.Context) error // IsValid() bool } + +// AppDriver is an extension of the Driver interface to capture non-RPC APIs. +type AppDriver interface { + Driver + + // ConnWithPluginID is only used by the server, and isn't exposed via the RPC API. + ConnWithPluginID(isMaster bool, pluginID string) (string, error) + // This is an extra method needed to shutdown connections + // after a plugin shuts down. + ShutdownConns(pluginID string) +} diff --git a/server/public/plugin/environment.go b/server/public/plugin/environment.go index ccca0466d2..21054ad857 100644 --- a/server/public/plugin/environment.go +++ b/server/public/plugin/environment.go @@ -54,7 +54,7 @@ type Environment struct { logger *mlog.Logger metrics metricsInterface newAPIImpl apiImplCreatorFunc - dbDriver Driver + dbDriver AppDriver pluginDir string webappPluginDir string prepackagedPlugins []*PrepackagedPlugin @@ -64,7 +64,7 @@ type Environment struct { func NewEnvironment( newAPIImpl apiImplCreatorFunc, - dbDriver Driver, + dbDriver AppDriver, pluginDir string, webappPluginDir string, logger *mlog.Logger, diff --git a/server/public/plugin/supervisor.go b/server/public/plugin/supervisor.go index 3ab7add785..48ada556b2 100644 --- a/server/public/plugin/supervisor.go +++ b/server/public/plugin/supervisor.go @@ -24,6 +24,8 @@ import ( type supervisor struct { lock sync.RWMutex + pluginID string + appDriver AppDriver client *plugin.Client hooks Hooks implemented [TotalHooksID]bool @@ -31,6 +33,15 @@ type supervisor struct { isReattached bool } +type driverForPlugin struct { + AppDriver + pluginID string +} + +func (d *driverForPlugin) Conn(isMaster bool) (string, error) { + return d.AppDriver.ConnWithPluginID(isMaster, d.pluginID) +} + func WithExecutableFromManifest(pluginInfo *model.BundleInfo) func(*supervisor, *plugin.ClientConfig) error { return func(_ *supervisor, clientConfig *plugin.ClientConfig) error { executable := pluginInfo.Manifest.GetExecutableForRuntime(runtime.GOOS, runtime.GOARCH) @@ -74,8 +85,14 @@ func WithReattachConfig(pluginReattachConfig *model.PluginReattachConfig) func(* } } -func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) { - sup := supervisor{} +func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver AppDriver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) { + sup := supervisor{ + pluginID: pluginInfo.Manifest.Id, + } + if driver != nil { + sup.appDriver = &driverForPlugin{AppDriver: driver, pluginID: pluginInfo.Manifest.Id} + } + defer func() { if retErr != nil { sup.Shutdown() @@ -92,7 +109,7 @@ func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, par pluginMap := map[string]plugin.Plugin{ "hooks": &hooksPlugin{ log: wrappedLogger, - driverImpl: driver, + driverImpl: sup.appDriver, apiImpl: &apiTimerLayer{pluginInfo.Manifest.Id, apiImpl, metrics}, }, } @@ -166,8 +183,12 @@ func (sup *supervisor) Shutdown() { } // Wait for API RPC server and DB RPC server to exit. + // And then shutdown conns. if sup.hooksClient != nil { sup.hooksClient.doneWg.Wait() + if sup.appDriver != nil { + sup.appDriver.ShutdownConns(sup.pluginID) + } } }