MM-56402: Introduce a pluginID to track RPC DB connections (#26424)

Previously, we relied on the plugin to close the DB connections
on shutdown. While this keeps the code simple, there is no guarantee
that the plugin author will remember to close the DB.

In that case, it's better to track the connections from the server side
and close them in case they weren't closed already. This complicates
the API slightly, but it's a price we need to pay.

https://mattermost.atlassian.net/browse/MM-56402

```release-note
We close any remaining unclosed DB RPC connections
after a plugin shuts down.
```


Co-authored-by: Jesse Hallam <jesse.hallam@gmail.com>
Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Agniva De Sarker 2024-04-16 18:53:26 +05:30 committed by GitHub
parent 6aaabfb376
commit effb99301e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 20 deletions

View File

@ -16,6 +16,7 @@ import (
) )
func TestCreateChannelBookmark(t *testing.T) { func TestCreateChannelBookmark(t *testing.T) {
t.Skip("MM-57312")
os.Setenv("MM_FEATUREFLAGS_ChannelBookmarks", "true") os.Setenv("MM_FEATUREFLAGS_ChannelBookmarks", "true")
defer os.Unsetenv("MM_FEATUREFLAGS_ChannelBookmarks") defer os.Unsetenv("MM_FEATUREFLAGS_ChannelBookmarks")

View File

@ -12,6 +12,7 @@ import (
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin" "github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/shared/mlog"
) )
// DriverImpl implements the plugin.Driver interface on the server-side. // DriverImpl implements the plugin.Driver interface on the server-side.
@ -21,7 +22,7 @@ import (
type DriverImpl struct { type DriverImpl struct {
s *Server s *Server
connMut sync.RWMutex connMut sync.RWMutex
connMap map[string]*sql.Conn connMap map[string]*connMeta
txMut sync.Mutex txMut sync.Mutex
txMap map[string]driver.Tx txMap map[string]driver.Tx
stMut sync.RWMutex stMut sync.RWMutex
@ -30,10 +31,15 @@ type DriverImpl struct {
rowsMap map[string]driver.Rows rowsMap map[string]driver.Rows
} }
type connMeta struct {
pluginID string
conn *sql.Conn
}
func NewDriverImpl(s *Server) *DriverImpl { func NewDriverImpl(s *Server) *DriverImpl {
return &DriverImpl{ return &DriverImpl{
s: s, s: s,
connMap: make(map[string]*sql.Conn), connMap: make(map[string]*connMeta),
txMap: make(map[string]driver.Tx), txMap: make(map[string]driver.Tx),
stMap: make(map[string]driver.Stmt), stMap: make(map[string]driver.Stmt),
rowsMap: make(map[string]driver.Rows), rowsMap: make(map[string]driver.Rows),
@ -41,6 +47,14 @@ func NewDriverImpl(s *Server) *DriverImpl {
} }
func (d *DriverImpl) Conn(isMaster bool) (string, error) { func (d *DriverImpl) Conn(isMaster bool) (string, error) {
return d.conn(isMaster, "")
}
func (d *DriverImpl) ConnWithPluginID(isMaster bool, pluginID string) (string, error) {
return d.conn(isMaster, pluginID)
}
func (d *DriverImpl) conn(isMaster bool, pluginID string) (string, error) {
dbFunc := d.s.Platform().Store.GetInternalMasterDB dbFunc := d.s.Platform().Store.GetInternalMasterDB
if !isMaster { if !isMaster {
dbFunc = d.s.Platform().Store.GetInternalReplicaDB dbFunc = d.s.Platform().Store.GetInternalReplicaDB
@ -54,7 +68,7 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) {
} }
connID := model.NewId() connID := model.NewId()
d.connMut.Lock() d.connMut.Lock()
d.connMap[connID] = conn d.connMap[connID] = &connMeta{pluginID: pluginID, conn: conn}
d.connMut.Unlock() d.connMut.Unlock()
return connID, nil return connID, nil
} }
@ -70,13 +84,13 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) {
func (d *DriverImpl) ConnPing(connID string) error { func (d *DriverImpl) ConnPing(connID string) error {
d.connMut.RLock() d.connMut.RLock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
d.connMut.RUnlock() d.connMut.RUnlock()
if !ok { if !ok {
return driver.ErrBadConn return driver.ErrBadConn
} }
return conn.Raw(func(innerConn any) error { return entry.conn.Raw(func(innerConn any) error {
return innerConn.(driver.Pinger).Ping(context.Background()) return innerConn.(driver.Pinger).Ping(context.Background())
}) })
} }
@ -84,13 +98,13 @@ func (d *DriverImpl) ConnPing(connID string) error {
func (d *DriverImpl) ConnQuery(connID, q string, args []driver.NamedValue) (_ string, err error) { func (d *DriverImpl) ConnQuery(connID, q string, args []driver.NamedValue) (_ string, err error) {
var rows driver.Rows var rows driver.Rows
d.connMut.RLock() d.connMut.RLock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
d.connMut.RUnlock() d.connMut.RUnlock()
if !ok { if !ok {
return "", driver.ErrBadConn return "", driver.ErrBadConn
} }
err = conn.Raw(func(innerConn any) error { err = entry.conn.Raw(func(innerConn any) error {
rows, err = innerConn.(driver.QueryerContext).QueryContext(context.Background(), q, args) rows, err = innerConn.(driver.QueryerContext).QueryContext(context.Background(), q, args)
return err return err
}) })
@ -110,13 +124,13 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu
var res driver.Result var res driver.Result
var ret plugin.ResultContainer var ret plugin.ResultContainer
d.connMut.RLock() d.connMut.RLock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
d.connMut.RUnlock() d.connMut.RUnlock()
if !ok { if !ok {
return ret, driver.ErrBadConn return ret, driver.ErrBadConn
} }
err = conn.Raw(func(innerConn any) error { err = entry.conn.Raw(func(innerConn any) error {
res, err = innerConn.(driver.ExecerContext).ExecContext(context.Background(), q, args) res, err = innerConn.(driver.ExecerContext).ExecContext(context.Background(), q, args)
return err return err
}) })
@ -132,7 +146,7 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu
func (d *DriverImpl) ConnClose(connID string) error { func (d *DriverImpl) ConnClose(connID string) error {
d.connMut.Lock() d.connMut.Lock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
if !ok { if !ok {
d.connMut.Unlock() d.connMut.Unlock()
return driver.ErrBadConn return driver.ErrBadConn
@ -140,19 +154,35 @@ func (d *DriverImpl) ConnClose(connID string) error {
delete(d.connMap, connID) delete(d.connMap, connID)
d.connMut.Unlock() d.connMut.Unlock()
return conn.Close() return entry.conn.Close()
}
// ShutdownConns will close any remaining connections for the given pluginID
// if they weren't already closed by the plugin itself.
func (d *DriverImpl) ShutdownConns(pluginID string) {
d.connMut.Lock()
defer d.connMut.Unlock()
for connID, entry := range d.connMap {
if entry.pluginID == pluginID {
err := entry.conn.Close()
if err != nil {
d.s.Log().Error("Error while closing DB connection", mlog.Err(err), mlog.String("pluginID", pluginID))
}
delete(d.connMap, connID)
}
}
} }
func (d *DriverImpl) Tx(connID string, opts driver.TxOptions) (_ string, err error) { func (d *DriverImpl) Tx(connID string, opts driver.TxOptions) (_ string, err error) {
var tx driver.Tx var tx driver.Tx
d.connMut.RLock() d.connMut.RLock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
d.connMut.RUnlock() d.connMut.RUnlock()
if !ok { if !ok {
return "", driver.ErrBadConn return "", driver.ErrBadConn
} }
err = conn.Raw(func(innerConn any) error { err = entry.conn.Raw(func(innerConn any) error {
tx, err = innerConn.(driver.ConnBeginTx).BeginTx(context.Background(), opts) tx, err = innerConn.(driver.ConnBeginTx).BeginTx(context.Background(), opts)
return err return err
}) })
@ -188,13 +218,13 @@ func (d *DriverImpl) TxRollback(txID string) error {
func (d *DriverImpl) Stmt(connID, q string) (_ string, err error) { func (d *DriverImpl) Stmt(connID, q string) (_ string, err error) {
var stmt driver.Stmt var stmt driver.Stmt
d.connMut.RLock() d.connMut.RLock()
conn, ok := d.connMap[connID] entry, ok := d.connMap[connID]
d.connMut.RUnlock() d.connMut.RUnlock()
if !ok { if !ok {
return "", driver.ErrBadConn return "", driver.ErrBadConn
} }
err = conn.Raw(func(innerConn any) error { err = entry.conn.Raw(func(innerConn any) error {
stmt, err = innerConn.(driver.Conn).Prepare(q) stmt, err = innerConn.(driver.Conn).Prepare(q)
return err return err
}) })

View File

@ -19,3 +19,22 @@ func TestConnCreateTimeout(t *testing.T) {
_, err := d.Conn(true) _, err := d.Conn(true)
require.Error(t, err) require.Error(t, err)
} }
func TestShutdownPluginConns(t *testing.T) {
th := Setup(t)
defer th.TearDown()
d := NewDriverImpl(th.Server)
_, err := d.ConnWithPluginID(true, "plugin1")
require.NoError(t, err)
_, err = d.ConnWithPluginID(true, "plugin2")
require.NoError(t, err)
_, err = d.ConnWithPluginID(true, "plugin1")
require.NoError(t, err)
require.Len(t, d.connMap, 3)
d.ShutdownConns("plugin1")
require.Len(t, d.connMap, 1)
d.ShutdownConns("plugin2")
require.Len(t, d.connMap, 0)
}

View File

@ -62,3 +62,14 @@ type Driver interface {
// ResetSession(ctx context.Context) error // ResetSession(ctx context.Context) error
// IsValid() bool // IsValid() bool
} }
// AppDriver is an extension of the Driver interface to capture non-RPC APIs.
type AppDriver interface {
Driver
// ConnWithPluginID is only used by the server, and isn't exposed via the RPC API.
ConnWithPluginID(isMaster bool, pluginID string) (string, error)
// This is an extra method needed to shutdown connections
// after a plugin shuts down.
ShutdownConns(pluginID string)
}

View File

@ -54,7 +54,7 @@ type Environment struct {
logger *mlog.Logger logger *mlog.Logger
metrics metricsInterface metrics metricsInterface
newAPIImpl apiImplCreatorFunc newAPIImpl apiImplCreatorFunc
dbDriver Driver dbDriver AppDriver
pluginDir string pluginDir string
webappPluginDir string webappPluginDir string
prepackagedPlugins []*PrepackagedPlugin prepackagedPlugins []*PrepackagedPlugin
@ -64,7 +64,7 @@ type Environment struct {
func NewEnvironment( func NewEnvironment(
newAPIImpl apiImplCreatorFunc, newAPIImpl apiImplCreatorFunc,
dbDriver Driver, dbDriver AppDriver,
pluginDir string, pluginDir string,
webappPluginDir string, webappPluginDir string,
logger *mlog.Logger, logger *mlog.Logger,

View File

@ -24,6 +24,8 @@ import (
type supervisor struct { type supervisor struct {
lock sync.RWMutex lock sync.RWMutex
pluginID string
appDriver AppDriver
client *plugin.Client client *plugin.Client
hooks Hooks hooks Hooks
implemented [TotalHooksID]bool implemented [TotalHooksID]bool
@ -31,6 +33,15 @@ type supervisor struct {
isReattached bool isReattached bool
} }
type driverForPlugin struct {
AppDriver
pluginID string
}
func (d *driverForPlugin) Conn(isMaster bool) (string, error) {
return d.AppDriver.ConnWithPluginID(isMaster, d.pluginID)
}
func WithExecutableFromManifest(pluginInfo *model.BundleInfo) func(*supervisor, *plugin.ClientConfig) error { func WithExecutableFromManifest(pluginInfo *model.BundleInfo) func(*supervisor, *plugin.ClientConfig) error {
return func(_ *supervisor, clientConfig *plugin.ClientConfig) error { return func(_ *supervisor, clientConfig *plugin.ClientConfig) error {
executable := pluginInfo.Manifest.GetExecutableForRuntime(runtime.GOOS, runtime.GOARCH) executable := pluginInfo.Manifest.GetExecutableForRuntime(runtime.GOOS, runtime.GOARCH)
@ -74,8 +85,14 @@ func WithReattachConfig(pluginReattachConfig *model.PluginReattachConfig) func(*
} }
} }
func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) { func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver AppDriver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) {
sup := supervisor{} sup := supervisor{
pluginID: pluginInfo.Manifest.Id,
}
if driver != nil {
sup.appDriver = &driverForPlugin{AppDriver: driver, pluginID: pluginInfo.Manifest.Id}
}
defer func() { defer func() {
if retErr != nil { if retErr != nil {
sup.Shutdown() sup.Shutdown()
@ -92,7 +109,7 @@ func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, par
pluginMap := map[string]plugin.Plugin{ pluginMap := map[string]plugin.Plugin{
"hooks": &hooksPlugin{ "hooks": &hooksPlugin{
log: wrappedLogger, log: wrappedLogger,
driverImpl: driver, driverImpl: sup.appDriver,
apiImpl: &apiTimerLayer{pluginInfo.Manifest.Id, apiImpl, metrics}, apiImpl: &apiTimerLayer{pluginInfo.Manifest.Id, apiImpl, metrics},
}, },
} }
@ -166,8 +183,12 @@ func (sup *supervisor) Shutdown() {
} }
// Wait for API RPC server and DB RPC server to exit. // Wait for API RPC server and DB RPC server to exit.
// And then shutdown conns.
if sup.hooksClient != nil { if sup.hooksClient != nil {
sup.hooksClient.doneWg.Wait() sup.hooksClient.doneWg.Wait()
if sup.appDriver != nil {
sup.appDriver.ShutdownConns(sup.pluginID)
}
} }
} }