MM-56402: Introduce a pluginID to track RPC DB connections (#26424)

Previously, we relied on the plugin to close the DB connections
on shutdown. While this keeps the code simple, there is no guarantee
that the plugin author will remember to close the DB.

In that case, it's better to track the connections from the server side
and close them in case they weren't closed already. This complicates
the API slightly, but it's a price we need to pay.

https://mattermost.atlassian.net/browse/MM-56402

```release-note
We close any remaining unclosed DB RPC connections
after a plugin shuts down.
```


Co-authored-by: Jesse Hallam <jesse.hallam@gmail.com>
Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Agniva De Sarker 2024-04-16 18:53:26 +05:30 committed by GitHub
parent 6aaabfb376
commit effb99301e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 20 deletions

View File

@ -16,6 +16,7 @@ import (
)
func TestCreateChannelBookmark(t *testing.T) {
t.Skip("MM-57312")
os.Setenv("MM_FEATUREFLAGS_ChannelBookmarks", "true")
defer os.Unsetenv("MM_FEATUREFLAGS_ChannelBookmarks")

View File

@ -12,6 +12,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
// DriverImpl implements the plugin.Driver interface on the server-side.
@ -21,7 +22,7 @@ import (
type DriverImpl struct {
s *Server
connMut sync.RWMutex
connMap map[string]*sql.Conn
connMap map[string]*connMeta
txMut sync.Mutex
txMap map[string]driver.Tx
stMut sync.RWMutex
@ -30,10 +31,15 @@ type DriverImpl struct {
rowsMap map[string]driver.Rows
}
type connMeta struct {
pluginID string
conn *sql.Conn
}
func NewDriverImpl(s *Server) *DriverImpl {
return &DriverImpl{
s: s,
connMap: make(map[string]*sql.Conn),
connMap: make(map[string]*connMeta),
txMap: make(map[string]driver.Tx),
stMap: make(map[string]driver.Stmt),
rowsMap: make(map[string]driver.Rows),
@ -41,6 +47,14 @@ func NewDriverImpl(s *Server) *DriverImpl {
}
func (d *DriverImpl) Conn(isMaster bool) (string, error) {
return d.conn(isMaster, "")
}
func (d *DriverImpl) ConnWithPluginID(isMaster bool, pluginID string) (string, error) {
return d.conn(isMaster, pluginID)
}
func (d *DriverImpl) conn(isMaster bool, pluginID string) (string, error) {
dbFunc := d.s.Platform().Store.GetInternalMasterDB
if !isMaster {
dbFunc = d.s.Platform().Store.GetInternalReplicaDB
@ -54,7 +68,7 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) {
}
connID := model.NewId()
d.connMut.Lock()
d.connMap[connID] = conn
d.connMap[connID] = &connMeta{pluginID: pluginID, conn: conn}
d.connMut.Unlock()
return connID, nil
}
@ -70,13 +84,13 @@ func (d *DriverImpl) Conn(isMaster bool) (string, error) {
func (d *DriverImpl) ConnPing(connID string) error {
d.connMut.RLock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
d.connMut.RUnlock()
if !ok {
return driver.ErrBadConn
}
return conn.Raw(func(innerConn any) error {
return entry.conn.Raw(func(innerConn any) error {
return innerConn.(driver.Pinger).Ping(context.Background())
})
}
@ -84,13 +98,13 @@ func (d *DriverImpl) ConnPing(connID string) error {
func (d *DriverImpl) ConnQuery(connID, q string, args []driver.NamedValue) (_ string, err error) {
var rows driver.Rows
d.connMut.RLock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
d.connMut.RUnlock()
if !ok {
return "", driver.ErrBadConn
}
err = conn.Raw(func(innerConn any) error {
err = entry.conn.Raw(func(innerConn any) error {
rows, err = innerConn.(driver.QueryerContext).QueryContext(context.Background(), q, args)
return err
})
@ -110,13 +124,13 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu
var res driver.Result
var ret plugin.ResultContainer
d.connMut.RLock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
d.connMut.RUnlock()
if !ok {
return ret, driver.ErrBadConn
}
err = conn.Raw(func(innerConn any) error {
err = entry.conn.Raw(func(innerConn any) error {
res, err = innerConn.(driver.ExecerContext).ExecContext(context.Background(), q, args)
return err
})
@ -132,7 +146,7 @@ func (d *DriverImpl) ConnExec(connID, q string, args []driver.NamedValue) (_ plu
func (d *DriverImpl) ConnClose(connID string) error {
d.connMut.Lock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
if !ok {
d.connMut.Unlock()
return driver.ErrBadConn
@ -140,19 +154,35 @@ func (d *DriverImpl) ConnClose(connID string) error {
delete(d.connMap, connID)
d.connMut.Unlock()
return conn.Close()
return entry.conn.Close()
}
// ShutdownConns will close any remaining connections for the given pluginID
// if they weren't already closed by the plugin itself.
func (d *DriverImpl) ShutdownConns(pluginID string) {
d.connMut.Lock()
defer d.connMut.Unlock()
for connID, entry := range d.connMap {
if entry.pluginID == pluginID {
err := entry.conn.Close()
if err != nil {
d.s.Log().Error("Error while closing DB connection", mlog.Err(err), mlog.String("pluginID", pluginID))
}
delete(d.connMap, connID)
}
}
}
func (d *DriverImpl) Tx(connID string, opts driver.TxOptions) (_ string, err error) {
var tx driver.Tx
d.connMut.RLock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
d.connMut.RUnlock()
if !ok {
return "", driver.ErrBadConn
}
err = conn.Raw(func(innerConn any) error {
err = entry.conn.Raw(func(innerConn any) error {
tx, err = innerConn.(driver.ConnBeginTx).BeginTx(context.Background(), opts)
return err
})
@ -188,13 +218,13 @@ func (d *DriverImpl) TxRollback(txID string) error {
func (d *DriverImpl) Stmt(connID, q string) (_ string, err error) {
var stmt driver.Stmt
d.connMut.RLock()
conn, ok := d.connMap[connID]
entry, ok := d.connMap[connID]
d.connMut.RUnlock()
if !ok {
return "", driver.ErrBadConn
}
err = conn.Raw(func(innerConn any) error {
err = entry.conn.Raw(func(innerConn any) error {
stmt, err = innerConn.(driver.Conn).Prepare(q)
return err
})

View File

@ -19,3 +19,22 @@ func TestConnCreateTimeout(t *testing.T) {
_, err := d.Conn(true)
require.Error(t, err)
}
func TestShutdownPluginConns(t *testing.T) {
th := Setup(t)
defer th.TearDown()
d := NewDriverImpl(th.Server)
_, err := d.ConnWithPluginID(true, "plugin1")
require.NoError(t, err)
_, err = d.ConnWithPluginID(true, "plugin2")
require.NoError(t, err)
_, err = d.ConnWithPluginID(true, "plugin1")
require.NoError(t, err)
require.Len(t, d.connMap, 3)
d.ShutdownConns("plugin1")
require.Len(t, d.connMap, 1)
d.ShutdownConns("plugin2")
require.Len(t, d.connMap, 0)
}

View File

@ -62,3 +62,14 @@ type Driver interface {
// ResetSession(ctx context.Context) error
// IsValid() bool
}
// AppDriver is an extension of the Driver interface to capture non-RPC APIs.
type AppDriver interface {
Driver
// ConnWithPluginID is only used by the server, and isn't exposed via the RPC API.
ConnWithPluginID(isMaster bool, pluginID string) (string, error)
// This is an extra method needed to shutdown connections
// after a plugin shuts down.
ShutdownConns(pluginID string)
}

View File

@ -54,7 +54,7 @@ type Environment struct {
logger *mlog.Logger
metrics metricsInterface
newAPIImpl apiImplCreatorFunc
dbDriver Driver
dbDriver AppDriver
pluginDir string
webappPluginDir string
prepackagedPlugins []*PrepackagedPlugin
@ -64,7 +64,7 @@ type Environment struct {
func NewEnvironment(
newAPIImpl apiImplCreatorFunc,
dbDriver Driver,
dbDriver AppDriver,
pluginDir string,
webappPluginDir string,
logger *mlog.Logger,

View File

@ -24,6 +24,8 @@ import (
type supervisor struct {
lock sync.RWMutex
pluginID string
appDriver AppDriver
client *plugin.Client
hooks Hooks
implemented [TotalHooksID]bool
@ -31,6 +33,15 @@ type supervisor struct {
isReattached bool
}
type driverForPlugin struct {
AppDriver
pluginID string
}
func (d *driverForPlugin) Conn(isMaster bool) (string, error) {
return d.AppDriver.ConnWithPluginID(isMaster, d.pluginID)
}
func WithExecutableFromManifest(pluginInfo *model.BundleInfo) func(*supervisor, *plugin.ClientConfig) error {
return func(_ *supervisor, clientConfig *plugin.ClientConfig) error {
executable := pluginInfo.Manifest.GetExecutableForRuntime(runtime.GOOS, runtime.GOARCH)
@ -74,8 +85,14 @@ func WithReattachConfig(pluginReattachConfig *model.PluginReattachConfig) func(*
}
}
func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) {
sup := supervisor{}
func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver AppDriver, parentLogger *mlog.Logger, metrics metricsInterface, opts ...func(*supervisor, *plugin.ClientConfig) error) (retSupervisor *supervisor, retErr error) {
sup := supervisor{
pluginID: pluginInfo.Manifest.Id,
}
if driver != nil {
sup.appDriver = &driverForPlugin{AppDriver: driver, pluginID: pluginInfo.Manifest.Id}
}
defer func() {
if retErr != nil {
sup.Shutdown()
@ -92,7 +109,7 @@ func newSupervisor(pluginInfo *model.BundleInfo, apiImpl API, driver Driver, par
pluginMap := map[string]plugin.Plugin{
"hooks": &hooksPlugin{
log: wrappedLogger,
driverImpl: driver,
driverImpl: sup.appDriver,
apiImpl: &apiTimerLayer{pluginInfo.Manifest.Id, apiImpl, metrics},
},
}
@ -166,8 +183,12 @@ func (sup *supervisor) Shutdown() {
}
// Wait for API RPC server and DB RPC server to exit.
// And then shutdown conns.
if sup.hooksClient != nil {
sup.hooksClient.doneWg.Wait()
if sup.appDriver != nil {
sup.appDriver.ShutdownConns(sup.pluginID)
}
}
}