PLT-7407: Back-end plugins (#7409)

* tie back-end plugins together

* fix comment typo

* add tests and a bit of polish

* tests and polish

* add test, don't let backend executable paths escape the plugin directory
This commit is contained in:
Chris
2017-09-11 10:02:02 -05:00
committed by GitHub
parent a69bed712d
commit 402491b7e5
19 changed files with 655 additions and 152 deletions

View File

@@ -150,7 +150,7 @@ func InitApi(full bool) {
BaseRoutes.PublicFile = BaseRoutes.Root.PathPrefix("/files/{file_id:[A-Za-z0-9]+}/public").Subrouter()
BaseRoutes.Plugins = BaseRoutes.ApiRoot.PathPrefix("/plugins").Subrouter()
BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-]+}").Subrouter()
BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}").Subrouter()
BaseRoutes.Commands = BaseRoutes.ApiRoot.PathPrefix("/commands").Subrouter()
BaseRoutes.Command = BaseRoutes.Commands.PathPrefix("/{command_id:[A-Za-z0-9]+}").Subrouter()

View File

@@ -110,5 +110,5 @@ func TestPlugin(t *testing.T) {
_, resp = th.SystemAdminClient.RemovePlugin("bad.id")
CheckNotFoundStatus(t, resp)
th.App.Srv.PluginEnv = nil
th.App.PluginEnv = nil
}

View File

@@ -6,10 +6,13 @@ package app
import (
"io/ioutil"
"net/http"
"github.com/mattermost/mattermost-server/plugin/pluginenv"
)
type App struct {
Srv *Server
Srv *Server
PluginEnv *pluginenv.Environment
}
var globalApp App

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"encoding/json"
"io"
"io/ioutil"
@@ -19,14 +20,17 @@ import (
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/utils"
"github.com/mattermost/mattermost-server/app/plugin"
builtinplugin "github.com/mattermost/mattermost-server/app/plugin"
"github.com/mattermost/mattermost-server/app/plugin/jira"
"github.com/mattermost/mattermost-server/app/plugin/ldapextras"
"github.com/mattermost/mattermost-server/plugin"
"github.com/mattermost/mattermost-server/plugin/pluginenv"
)
type PluginAPI struct {
id string
router *mux.Router
id string
app *App
}
func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error {
@@ -37,37 +41,67 @@ func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error {
}
}
func (api *PluginAPI) PluginRouter() *mux.Router {
return api.router
}
func (api *PluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) {
return Global().GetTeamByName(name)
return api.app.GetTeamByName(name)
}
func (api *PluginAPI) GetUserByName(name string) (*model.User, *model.AppError) {
return Global().GetUserByUsername(name)
func (api *PluginAPI) GetUserByUsername(name string) (*model.User, *model.AppError) {
return api.app.GetUserByUsername(name)
}
func (api *PluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) {
return Global().GetChannelByName(name, teamId)
}
func (api *PluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
return Global().GetDirectChannel(userId1, userId2)
func (api *PluginAPI) GetChannelByName(name, teamId string) (*model.Channel, *model.AppError) {
return api.app.GetChannelByName(name, teamId)
}
func (api *PluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) {
return Global().CreatePostMissingChannel(post, true)
return api.app.CreatePostMissingChannel(post, true)
}
func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) {
type BuiltInPluginAPI struct {
id string
router *mux.Router
app *App
}
func (api *BuiltInPluginAPI) LoadPluginConfiguration(dest interface{}) error {
if b, err := json.Marshal(utils.Cfg.PluginSettings.Plugins[api.id]); err != nil {
return err
} else {
return json.Unmarshal(b, dest)
}
}
func (api *BuiltInPluginAPI) PluginRouter() *mux.Router {
return api.router
}
func (api *BuiltInPluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) {
return api.app.GetTeamByName(name)
}
func (api *BuiltInPluginAPI) GetUserByName(name string) (*model.User, *model.AppError) {
return api.app.GetUserByUsername(name)
}
func (api *BuiltInPluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) {
return api.app.GetChannelByName(name, teamId)
}
func (api *BuiltInPluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
return api.app.GetDirectChannel(userId1, userId2)
}
func (api *BuiltInPluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) {
return api.app.CreatePostMissingChannel(post, true)
}
func (api *BuiltInPluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) {
ldapInterface := einterfaces.GetLdapInterface()
if ldapInterface == nil {
return nil, model.NewAppError("GetLdapUserAttributes", "ent.ldap.disabled.app_error", nil, "", http.StatusNotImplemented)
}
user, err := Global().GetUser(userId)
user, err := api.app.GetUser(userId)
if err != nil {
return nil, err
}
@@ -75,7 +109,7 @@ func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string)
return ldapInterface.GetUserAttributes(*user.AuthData, attributes)
}
func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) {
func (api *BuiltInPluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) {
token := ""
isTokenFromQueryString := false
@@ -111,7 +145,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m
return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
}
session, err := Global().GetSession(token)
session, err := api.app.GetSession(token)
if err != nil {
return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
@@ -122,7 +156,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m
return session, nil
}
func (api *PluginAPI) I18n(id string, r *http.Request) string {
func (api *BuiltInPluginAPI) I18n(id string, r *http.Request) string {
if r != nil {
f, _ := utils.GetTranslationsAndLocale(nil, r)
return f(id)
@@ -131,16 +165,17 @@ func (api *PluginAPI) I18n(id string, r *http.Request) string {
return f(id)
}
func (a *App) InitPlugins() {
plugins := map[string]plugin.Plugin{
func (a *App) InitBuiltInPlugins() {
plugins := map[string]builtinplugin.Plugin{
"jira": &jira.Plugin{},
"ldapextras": &ldapextras.Plugin{},
}
for id, p := range plugins {
l4g.Info("Initializing plugin: " + id)
api := &PluginAPI{
api := &BuiltInPluginAPI{
id: id,
router: a.Srv.Router.PathPrefix("/plugins/" + id).Subrouter(),
app: a,
}
p.Initialize(api)
}
@@ -155,19 +190,19 @@ func (a *App) InitPlugins() {
}
func (a *App) ActivatePlugins() {
if a.Srv.PluginEnv == nil {
if a.PluginEnv == nil {
l4g.Error("plugin env not initialized")
return
}
plugins, err := a.Srv.PluginEnv.Plugins()
plugins, err := a.PluginEnv.Plugins()
if err != nil {
l4g.Error("failed to start up plugins: " + err.Error())
return
}
for _, plugin := range plugins {
err := a.Srv.PluginEnv.ActivatePlugin(plugin.Manifest.Id)
err := a.PluginEnv.ActivatePlugin(plugin.Manifest.Id)
if err != nil {
l4g.Error(err.Error())
}
@@ -176,48 +211,43 @@ func (a *App) ActivatePlugins() {
}
func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *model.AppError) {
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
}
tmpDir, err := ioutil.TempDir("", "plugintmp")
if err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.temp_dir.app_error", nil, err.Error(), http.StatusInternalServerError)
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError)
}
defer func() {
os.RemoveAll(tmpDir)
}()
defer os.RemoveAll(tmpDir)
filenames, err := utils.ExtractTarGz(pluginFile, tmpDir)
if err != nil {
if err := utils.ExtractTarGz(pluginFile, tmpDir); err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.extract.app_error", nil, err.Error(), http.StatusBadRequest)
}
if len(filenames) == 0 {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.no_files.app_error", nil, err.Error(), http.StatusBadRequest)
tmpPluginDir := tmpDir
dir, err := ioutil.ReadDir(tmpDir)
if err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError)
}
splitPath := strings.Split(filenames[0], string(os.PathSeparator))
if len(splitPath) == 0 {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.bad_path.app_error", nil, err.Error(), http.StatusBadRequest)
if len(dir) == 1 && dir[0].IsDir() {
tmpPluginDir = filepath.Join(tmpPluginDir, dir[0].Name())
}
manifestDir := filepath.Join(tmpDir, splitPath[0])
manifest, _, err := model.FindManifest(manifestDir)
manifest, _, err := model.FindManifest(tmpPluginDir)
if err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.manifest.app_error", nil, err.Error(), http.StatusBadRequest)
}
os.Rename(manifestDir, filepath.Join(a.Srv.PluginEnv.SearchPath(), manifest.Id))
os.Rename(tmpPluginDir, filepath.Join(a.PluginEnv.SearchPath(), manifest.Id))
if err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.mvdir.app_error", nil, err.Error(), http.StatusInternalServerError)
}
// Should add manifest validation and error handling here
err = a.Srv.PluginEnv.ActivatePlugin(manifest.Id)
err = a.PluginEnv.ActivatePlugin(manifest.Id)
if err != nil {
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.activate.app_error", nil, err.Error(), http.StatusBadRequest)
}
@@ -226,11 +256,11 @@ func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *m
}
func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) {
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
}
plugins, err := a.Srv.PluginEnv.ActivePlugins()
plugins, err := a.PluginEnv.ActivePlugins()
if err != nil {
return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.get_plugins.app_error", nil, err.Error(), http.StatusInternalServerError)
}
@@ -244,16 +274,16 @@ func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) {
}
func (a *App) RemovePlugin(id string) *model.AppError {
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
return model.NewAppError("RemovePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
}
err := a.Srv.PluginEnv.DeactivatePlugin(id)
err := a.PluginEnv.DeactivatePlugin(id)
if err != nil {
return model.NewAppError("RemovePlugin", "app.plugin.deactivate.app_error", nil, err.Error(), http.StatusBadRequest)
}
err = os.RemoveAll(filepath.Join(a.Srv.PluginEnv.SearchPath(), id))
err = os.RemoveAll(filepath.Join(a.PluginEnv.SearchPath(), id))
if err != nil {
return model.NewAppError("RemovePlugin", "app.plugin.remove.app_error", nil, err.Error(), http.StatusInternalServerError)
}
@@ -268,11 +298,11 @@ type ClientConfigPlugin struct {
}
func (a *App) GetPluginsForClientConfig() string {
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
return ""
}
plugins, err := a.Srv.PluginEnv.ActivePlugins()
plugins, err := a.PluginEnv.ActivePlugins()
if err != nil {
return ""
}
@@ -292,3 +322,94 @@ func (a *App) GetPluginsForClientConfig() string {
return string(b)
}
func (a *App) InitPlugins(pluginPath, webappPath string) {
a.InitBuiltInPlugins()
if !utils.IsLicensed() || !*utils.License().Features.FutureFeatures || !*utils.Cfg.PluginSettings.Enable {
return
}
l4g.Info("Starting up plugins")
err := os.Mkdir(pluginPath, 0744)
if err != nil && !os.IsExist(err) {
l4g.Error("failed to start up plugins: " + err.Error())
return
}
a.PluginEnv, err = pluginenv.New(
pluginenv.SearchPath(pluginPath),
pluginenv.WebappPath(webappPath),
pluginenv.APIProvider(func(m *model.Manifest) (plugin.API, error) {
return &PluginAPI{
id: m.Id,
app: a,
}, nil
}),
)
if err != nil {
l4g.Error("failed to start up plugins: " + err.Error())
return
}
utils.AddConfigListener(func(_, _ *model.Config) {
for _, err := range a.PluginEnv.Hooks().OnConfigurationChange() {
l4g.Error(err.Error())
}
})
a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}", a.ServePluginRequest)
a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}/{anything:.*}", a.ServePluginRequest)
a.ActivatePlugins()
}
func (a *App) ServePluginRequest(w http.ResponseWriter, r *http.Request) {
token := ""
authHeader := r.Header.Get(model.HEADER_AUTH)
if strings.HasPrefix(strings.ToUpper(authHeader), model.HEADER_BEARER+":") {
token = authHeader[len(model.HEADER_BEARER)+1:]
} else if strings.HasPrefix(strings.ToLower(authHeader), model.HEADER_TOKEN+":") {
token = authHeader[len(model.HEADER_TOKEN)+1:]
} else if cookie, _ := r.Cookie(model.SESSION_COOKIE_TOKEN); cookie != nil && (r.Method == "GET" || r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML) {
token = cookie.Value
} else {
token = r.URL.Query().Get("access_token")
}
r.Header.Del("Mattermost-User-Id")
if token != "" {
if session, err := a.GetSession(token); err != nil {
r.Header.Set("Mattermost-User-Id", session.UserId)
}
}
cookies := r.Cookies()
r.Header.Del("Cookie")
for _, c := range cookies {
if c.Name != model.SESSION_COOKIE_TOKEN {
r.AddCookie(c)
}
}
r.Header.Del(model.HEADER_AUTH)
r.Header.Del("Referer")
newQuery := r.URL.Query()
newQuery.Del("access_token")
r.URL.RawQuery = newQuery.Encode()
params := mux.Vars(r)
a.PluginEnv.Hooks().ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "plugin_id", params["plugin_id"])))
}
func (a *App) ShutDownPlugins() {
if a.PluginEnv == nil {
return
}
for _, err := range a.PluginEnv.Shutdown() {
l4g.Error(err.Error())
}
}

View File

@@ -7,7 +7,6 @@ import (
"crypto/tls"
"net"
"net/http"
"os"
"strings"
"time"
@@ -20,7 +19,6 @@ import (
"gopkg.in/throttled/throttled.v2/store/memstore"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/plugin/pluginenv"
"github.com/mattermost/mattermost-server/store"
"github.com/mattermost/mattermost-server/utils"
)
@@ -30,7 +28,6 @@ type Server struct {
WebSocketRouter *WebSocketRouter
Router *mux.Router
GracefulServer *graceful.Server
PluginEnv *pluginenv.Environment
}
var allowedMethods []string = []string{
@@ -187,10 +184,6 @@ func (a *App) StartServer() {
}()
}
if utils.IsLicensed() && *utils.License().Features.FutureFeatures && *utils.Cfg.PluginSettings.Enable {
a.StartupPlugins("plugins", "webapp/dist")
}
go func() {
var err error
if *utils.Cfg.ServiceSettings.ConnectionSecurity == model.CONN_SECURITY_TLS {
@@ -226,30 +219,7 @@ func (a *App) StopServer() {
a.Srv.Store.Close()
HubStop()
a.ShutDownPlugins()
l4g.Info(utils.T("api.server.stop_server.stopped.info"))
}
func (a *App) StartupPlugins(pluginPath, webappPath string) {
l4g.Info("Starting up plugins")
err := os.Mkdir(pluginPath, 0744)
if err != nil {
if os.IsExist(err) {
err = nil
} else {
l4g.Error("failed to start up plugins: " + err.Error())
return
}
}
a.Srv.PluginEnv, err = pluginenv.New(
pluginenv.SearchPath(pluginPath),
pluginenv.WebappPath(webappPath),
)
if err != nil {
l4g.Error("failed to start up plugins: " + err.Error())
}
a.ActivatePlugins()
}

View File

@@ -71,20 +71,22 @@ func runServer(configFileLocation string) {
l4g.Error("Problem with file storage settings: " + err.Error())
}
app.Global().NewServer()
app.Global().InitStores()
a := app.Global()
a.NewServer()
a.InitStores()
api.InitRouter()
if model.BuildEnterpriseReady == "true" {
a.LoadLicense()
}
a.InitPlugins("plugins", "webapp/dist")
wsapi.InitRouter()
api4.InitApi(false)
api.InitApi()
app.Global().InitPlugins()
wsapi.InitApi()
web.InitWeb()
if model.BuildEnterpriseReady == "true" {
app.Global().LoadLicense()
}
if !utils.IsLicensed() && len(utils.Cfg.SqlSettings.DataSourceReplicas) > 1 {
l4g.Warn(utils.T("store.sql.read_replicas_not_licensed.critical"))
utils.Cfg.SqlSettings.DataSourceReplicas = utils.Cfg.SqlSettings.DataSourceReplicas[:1]
@@ -98,7 +100,7 @@ func runServer(configFileLocation string) {
resetStatuses()
app.Global().StartServer()
a.StartServer()
// If we allow testing then listen for manual testing URL hits
if utils.Cfg.ServiceSettings.EnableTesting {
@@ -118,7 +120,7 @@ func runServer(configFileLocation string) {
}
if einterfaces.GetClusterInterface() != nil {
app.Global().RegisterAllClusterMessageHandlers()
a.RegisterAllClusterMessageHandlers()
einterfaces.GetClusterInterface().StartInterNodeCommunication()
}
@@ -132,7 +134,7 @@ func runServer(configFileLocation string) {
}
}
jobs.Srv.Store = app.Global().Srv.Store
jobs.Srv.Store = a.Srv.Store
if *utils.Cfg.JobSettings.RunJobs {
jobs.Srv.StartWorkers()
}
@@ -157,7 +159,7 @@ func runServer(configFileLocation string) {
jobs.Srv.StopSchedulers()
jobs.Srv.StopWorkers()
app.Global().StopServer()
a.StopServer()
}
func runSecurityJob() {

View File

@@ -3427,10 +3427,6 @@
"id": "app.plugin.activate.app_error",
"translation": "Unable to activate extracted plugin. Plugin may already exist and be activated."
},
{
"id": "app.plugin.bad_path.app_error",
"translation": "Bad file path in extracted files"
},
{
"id": "app.plugin.deactivate.app_error",
"translation": "Unable to deactivate plugin"
@@ -3443,6 +3439,10 @@
"id": "app.plugin.extract.app_error",
"translation": "Encountered error extracting plugin"
},
{
"id": "app.plugin.filesystem.app_error",
"translation": "Encountered filesystem error"
},
{
"id": "app.plugin.get_plugins.app_error",
"translation": "Unable to get active plugins"
@@ -3455,10 +3455,6 @@
"id": "app.plugin.mvdir.app_error",
"translation": "Unable to move plugin from temporary directory to final destination"
},
{
"id": "app.plugin.no_files.app_error",
"translation": "No files found in the compressed folder"
},
{
"id": "app.plugin.remove.app_error",
"translation": "Unable to delete plugin"

View File

@@ -12,6 +12,9 @@ type Hooks interface {
// use the API, and the plugin will be terminated shortly after this invocation.
OnDeactivate() error
// OnConfigurationChange is invoked when configuration changes may have been made.
OnConfigurationChange() error
// ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for
// the /plugins/{id} path will be routed to the plugin.
//

View File

@@ -4,6 +4,7 @@ package pluginenv
import (
"fmt"
"io/ioutil"
"net/http"
"sync"
"github.com/pkg/errors"
@@ -27,7 +28,7 @@ type Environment struct {
apiProvider APIProviderFunc
supervisorProvider SupervisorProviderFunc
activePlugins map[string]ActivePlugin
mutex sync.Mutex
mutex sync.RWMutex
}
type Option func(*Environment)
@@ -61,15 +62,13 @@ func (env *Environment) SearchPath() string {
// Returns a list of all plugins found within the environment.
func (env *Environment) Plugins() ([]*model.BundleInfo, error) {
env.mutex.Lock()
defer env.mutex.Unlock()
return ScanSearchPath(env.searchPath)
}
// Returns a list of all currently active plugins within the environment.
func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
env.mutex.Lock()
defer env.mutex.Unlock()
env.mutex.RLock()
defer env.mutex.RUnlock()
activePlugins := []*model.BundleInfo{}
for _, p := range env.activePlugins {
@@ -81,8 +80,8 @@ func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
// Returns the ids of the currently active plugins.
func (env *Environment) ActivePluginIds() (ids []string) {
env.mutex.Lock()
defer env.mutex.Unlock()
env.mutex.RLock()
defer env.mutex.RUnlock()
for id := range env.activePlugins {
ids = append(ids, id)
@@ -200,13 +199,55 @@ func (env *Environment) Shutdown() (errs []error) {
for _, activePlugin := range env.activePlugins {
if activePlugin.Supervisor != nil {
if err := activePlugin.Supervisor.Hooks().OnDeactivate(); err != nil {
errs = append(errs, err)
errs = append(errs, errors.Wrapf(err, "OnDeactivate() error for %v", activePlugin.BundleInfo.Manifest.Id))
}
if err := activePlugin.Supervisor.Stop(); err != nil {
errs = append(errs, err)
errs = append(errs, errors.Wrapf(err, "error stopping supervisor for %v", activePlugin.BundleInfo.Manifest.Id))
}
}
}
env.activePlugins = make(map[string]ActivePlugin)
return
}
type EnvironmentHooks struct {
env *Environment
}
func (env *Environment) Hooks() *EnvironmentHooks {
return &EnvironmentHooks{env}
}
// OnConfigurationChange invokes the OnConfigurationChange hook for all plugins. Any errors
// encountered will be returned.
func (h *EnvironmentHooks) OnConfigurationChange() (errs []error) {
h.env.mutex.RLock()
defer h.env.mutex.RUnlock()
for _, activePlugin := range h.env.activePlugins {
if activePlugin.Supervisor == nil {
continue
}
if err := activePlugin.Supervisor.Hooks().OnConfigurationChange(); err != nil {
errs = append(errs, errors.Wrapf(err, "OnConfigurationChange error for %v", activePlugin.BundleInfo.Manifest.Id))
}
}
return
}
// ServeHTTP invokes the ServeHTTP hook for the plugin identified by the request or responds with a
// 404 not found.
//
// It expects the request's context to have a plugin_id set.
func (h *EnvironmentHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if id := r.Context().Value("plugin_id"); id != nil {
if idstr, ok := id.(string); ok {
h.env.mutex.RLock()
defer h.env.mutex.RUnlock()
if plugin, ok := h.env.activePlugins[idstr]; ok && plugin.Supervisor != nil {
plugin.Supervisor.Hooks().ServeHTTP(w, r)
return
}
}
}
http.NotFound(w, r)
}

View File

@@ -1,10 +1,14 @@
package pluginenv
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -298,3 +302,70 @@ func TestEnvironment_ShutdownError(t *testing.T) {
assert.Equal(t, env.ActivePluginIds(), []string{"foo"})
assert.Len(t, env.Shutdown(), 2)
}
func TestEnvironment_ConcurrentHookInvocations(t *testing.T) {
dir := initTmpDir(t, map[string]string{
"foo/plugin.json": `{"id": "foo", "backend": {}}`,
})
defer os.RemoveAll(dir)
var provider MockProvider
defer provider.AssertExpectations(t)
var api struct{ plugin.API }
var supervisor MockSupervisor
defer supervisor.AssertExpectations(t)
var hooks plugintest.Hooks
defer hooks.AssertExpectations(t)
env, err := New(
SearchPath(dir),
APIProvider(provider.API),
SupervisorProvider(provider.Supervisor),
)
require.NoError(t, err)
defer env.Shutdown()
provider.On("API").Return(&api, nil)
provider.On("Supervisor").Return(&supervisor, nil)
supervisor.On("Start").Return(nil)
supervisor.On("Stop").Return(nil)
supervisor.On("Hooks").Return(&hooks)
ch := make(chan bool)
hooks.On("OnActivate", &api).Return(nil)
hooks.On("OnDeactivate").Return(nil)
hooks.On("ServeHTTP", mock.AnythingOfType("*httptest.ResponseRecorder"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
r := args.Get(1).(*http.Request)
if r.URL.Path == "/1" {
<-ch
} else {
ch <- true
}
})
assert.NoError(t, env.ActivatePlugin("foo"))
rec := httptest.NewRecorder()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
req, err := http.NewRequest("GET", "/1", nil)
require.NoError(t, err)
env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
wg.Done()
}()
go func() {
req, err := http.NewRequest("GET", "/2", nil)
require.NoError(t, err)
env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
wg.Done()
}()
wg.Wait()
}

View File

@@ -22,6 +22,10 @@ func (m *Hooks) OnDeactivate() error {
return m.Called().Error(0)
}
func (m *Hooks) OnConfigurationChange() error {
return m.Called().Error(0)
}
func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.Called(w, r)
}

View File

@@ -86,6 +86,15 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) {
return
}
func (h *LocalHooks) OnConfigurationChange(args, reply *struct{}) error {
if hook, ok := h.hooks.(interface {
OnConfigurationChange() error
}); ok {
return hook.OnConfigurationChange()
}
return nil
}
type ServeHTTPArgs struct {
ResponseWriterStream int64
Request *http.Request
@@ -122,11 +131,14 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) {
server.ServeConn(conn)
}
// These assignments are part of the wire protocol. You can add more, but should not change existing
// assignments.
const (
remoteOnActivate = iota
remoteOnDeactivate
remoteServeHTTP
maxRemoteHookCount
remoteOnActivate = 0
remoteOnDeactivate = 1
remoteServeHTTP = 2
remoteOnConfigurationChange = 3
maxRemoteHookCount = iota
)
type RemoteHooks struct {
@@ -164,6 +176,13 @@ func (h *RemoteHooks) OnDeactivate() error {
return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil)
}
func (h *RemoteHooks) OnConfigurationChange() error {
if !h.implemented[remoteOnConfigurationChange] {
return nil
}
return h.client.Call("LocalHooks.OnConfigurationChange", struct{}{}, nil)
}
func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !h.implemented[remoteServeHTTP] {
http.NotFound(w, r)
@@ -227,6 +246,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) {
remote.implemented[remoteOnActivate] = true
case "OnDeactivate":
remote.implemented[remoteOnDeactivate] = true
case "OnConfigurationChange":
remote.implemented[remoteOnConfigurationChange] = true
case "ServeHTTP":
remote.implemented[remoteServeHTTP] = true
}

View File

@@ -6,10 +6,12 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost-server/plugin"
"github.com/mattermost/mattermost-server/plugin/plugintest"
@@ -50,6 +52,9 @@ func TestHooks(t *testing.T) {
hooks.On("OnDeactivate").Return(nil)
assert.NoError(t, remote.OnDeactivate())
hooks.On("OnConfigurationChange").Return(nil)
assert.NoError(t, remote.OnConfigurationChange())
hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
w := args.Get(0).(http.ResponseWriter)
r := args.Get(1).(*http.Request)
@@ -77,6 +82,45 @@ func TestHooks(t *testing.T) {
}))
}
func TestHooks_Concurrency(t *testing.T) {
var hooks plugintest.Hooks
defer hooks.AssertExpectations(t)
assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) {
ch := make(chan bool)
hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
r := args.Get(1).(*http.Request)
if r.URL.Path == "/1" {
<-ch
} else {
ch <- true
}
})
rec := httptest.NewRecorder()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
req, err := http.NewRequest("GET", "/1", nil)
require.NoError(t, err)
remote.ServeHTTP(rec, req)
wg.Done()
}()
go func() {
req, err := http.NewRequest("GET", "/2", nil)
require.NoError(t, err)
remote.ServeHTTP(rec, req)
wg.Done()
}()
wg.Wait()
}))
}
type testHooks struct {
mock.Mock
}

View File

@@ -2,26 +2,169 @@ package rpcplugin
import (
"bufio"
"bytes"
"encoding/binary"
"io"
"os"
"sync"
)
type asyncRead struct {
b []byte
err error
}
type asyncReadCloser struct {
io.ReadCloser
buffer bytes.Buffer
read chan struct{}
reads chan asyncRead
close chan struct{}
closeOnce sync.Once
}
// NewAsyncReadCloser returns a ReadCloser that supports Close during Read.
func NewAsyncReadCloser(r io.ReadCloser) io.ReadCloser {
ret := &asyncReadCloser{
ReadCloser: r,
read: make(chan struct{}),
reads: make(chan asyncRead),
close: make(chan struct{}),
}
go ret.loop()
return ret
}
func (r *asyncReadCloser) loop() {
buf := make([]byte, 1024*8)
var n int
var err error
for {
select {
case <-r.read:
n = 0
if err == nil {
n, err = r.ReadCloser.Read(buf)
}
select {
case r.reads <- asyncRead{buf[:n], err}:
case <-r.close:
}
case <-r.close:
r.ReadCloser.Close()
return
}
}
}
func (r *asyncReadCloser) Read(b []byte) (int, error) {
if r.buffer.Len() > 0 {
return r.buffer.Read(b)
}
select {
case r.read <- struct{}{}:
case <-r.close:
}
select {
case read := <-r.reads:
if read.err != nil {
return 0, read.err
}
n := copy(b, read.b)
if n < len(read.b) {
r.buffer.Write(read.b[n:])
}
return n, nil
case <-r.close:
return 0, io.EOF
}
}
func (r *asyncReadCloser) Close() error {
r.closeOnce.Do(func() {
close(r.close)
})
return nil
}
type asyncWrite struct {
n int
err error
}
type asyncWriteCloser struct {
io.WriteCloser
writeBuffer bytes.Buffer
write chan struct{}
writes chan asyncWrite
close chan struct{}
closeOnce sync.Once
}
// NewAsyncWriteCloser returns a WriteCloser that supports Close during Write.
func NewAsyncWriteCloser(w io.WriteCloser) io.WriteCloser {
ret := &asyncWriteCloser{
WriteCloser: w,
write: make(chan struct{}),
writes: make(chan asyncWrite),
close: make(chan struct{}),
}
go ret.loop()
return ret
}
func (w *asyncWriteCloser) loop() {
var n int64
var err error
for {
select {
case <-w.write:
n = 0
if err == nil {
n, err = w.writeBuffer.WriteTo(w.WriteCloser)
}
select {
case w.writes <- asyncWrite{int(n), err}:
case <-w.close:
}
case <-w.close:
w.WriteCloser.Close()
return
}
}
}
func (w *asyncWriteCloser) Write(b []byte) (int, error) {
if n, err := w.writeBuffer.Write(b); err != nil {
return n, err
}
select {
case w.write <- struct{}{}:
case <-w.close:
}
select {
case write := <-w.writes:
return write.n, write.err
case <-w.close:
return 0, io.EOF
}
}
func (w *asyncWriteCloser) Close() error {
w.closeOnce.Do(func() {
close(w.close)
})
return nil
}
type rwc struct {
io.ReadCloser
io.WriteCloser
}
func (rwc *rwc) Close() (err error) {
if f, ok := rwc.ReadCloser.(*os.File); ok {
// https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion
err = os.NewFile(f.Fd(), "").Close()
} else {
err = rwc.ReadCloser.Close()
}
werr := rwc.WriteCloser.Close()
if err == nil {
err = werr
err = rwc.WriteCloser.Close()
if rerr := rwc.ReadCloser.Close(); err == nil {
err = rerr
}
return
}

View File

@@ -0,0 +1,73 @@
package rpcplugin
import (
"io/ioutil"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAsyncReadCloser(t *testing.T) {
rf, w, err := os.Pipe()
require.NoError(t, err)
r := NewAsyncReadCloser(rf)
defer r.Close()
go func() {
w.Write([]byte("foo"))
w.Close()
}()
foo, err := ioutil.ReadAll(r)
require.NoError(t, err)
assert.Equal(t, "foo", string(foo))
}
func TestNewAsyncReadCloser_CloseDuringRead(t *testing.T) {
rf, w, err := os.Pipe()
require.NoError(t, err)
defer w.Close()
r := NewAsyncReadCloser(rf)
go func() {
time.Sleep(time.Millisecond * 200)
r.Close()
}()
r.Read(make([]byte, 10))
}
func TestNewAsyncWriteCloser(t *testing.T) {
r, wf, err := os.Pipe()
require.NoError(t, err)
w := NewAsyncWriteCloser(wf)
defer w.Close()
go func() {
foo, err := ioutil.ReadAll(r)
require.NoError(t, err)
assert.Equal(t, "foo", string(foo))
r.Close()
}()
n, err := w.Write([]byte("foo"))
require.NoError(t, err)
assert.Equal(t, 3, n)
}
func TestNewAsyncWriteCloser_CloseDuringWrite(t *testing.T) {
r, wf, err := os.Pipe()
require.NoError(t, err)
defer r.Close()
w := NewAsyncWriteCloser(wf)
go func() {
time.Sleep(time.Millisecond * 200)
w.Close()
}()
w.Write(make([]byte, 10))
}

View File

@@ -19,7 +19,7 @@ func NewIPC() (io.ReadWriteCloser, []*os.File, error) {
childWriter.Close()
return nil, nil, err
}
return NewReadWriteCloser(parentReader, parentWriter), []*os.File{childReader, childWriter}, nil
return NewReadWriteCloser(NewAsyncReadCloser(parentReader), NewAsyncWriteCloser(parentWriter)), []*os.File{childReader, childWriter}, nil
}
// Returns the IPC instance inherited by the process from its parent.

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"path/filepath"
"strings"
"sync/atomic"
"time"
@@ -123,7 +124,11 @@ func SupervisorProvider(bundle *model.BundleInfo) (plugin.Supervisor, error) {
} else if bundle.Manifest.Backend == nil || bundle.Manifest.Backend.Executable == "" {
return nil, fmt.Errorf("no backend executable specified")
}
executable := filepath.Clean(filepath.Join(".", bundle.Manifest.Backend.Executable))
if strings.HasPrefix(executable, "..") {
return nil, fmt.Errorf("invalid backend executable")
}
return &Supervisor{
executable: filepath.Join(bundle.Path, bundle.Manifest.Backend.Executable),
executable: filepath.Join(bundle.Path, executable),
}, nil
}

View File

@@ -43,6 +43,19 @@ func TestSupervisor(t *testing.T) {
require.NoError(t, supervisor.Stop())
}
func TestSupervisor_InvalidExecutablePath(t *testing.T) {
dir, err := ioutil.TempDir("", "")
require.NoError(t, err)
defer os.RemoveAll(dir)
ioutil.WriteFile(filepath.Join(dir, "plugin.json"), []byte(`{"id": "foo", "backend": {"executable": "/foo/../../backend.exe"}}`), 0600)
bundle := model.BundleInfoForPath(dir)
supervisor, err := SupervisorProvider(bundle)
assert.Nil(t, supervisor)
assert.Error(t, err)
}
// If plugin development goes really wrong, let's make sure plugin activation won't block forever.
func TestSupervisor_StartTimeout(t *testing.T) {
dir, err := ioutil.TempDir("", "")

View File

@@ -13,19 +13,16 @@ import (
)
// ExtractTarGz takes in an io.Reader containing the bytes for a .tar.gz file and
// a destination string to extract to. A list of the file and directory names that
// were extracted is returned.
func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) {
// a destination string to extract to.
func ExtractTarGz(gzipStream io.Reader, dst string) error {
uncompressedStream, err := gzip.NewReader(gzipStream)
if err != nil {
return nil, fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error())
}
defer uncompressedStream.Close()
tarReader := tar.NewReader(uncompressedStream)
filenames := []string{}
for true {
header, err := tarReader.Next()
@@ -34,50 +31,46 @@ func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) {
}
if err != nil {
return nil, fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error())
}
switch header.Typeflag {
case tar.TypeDir:
if PathTraversesUpward(header.Name) {
return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
}
path := filepath.Join(dst, header.Name)
if err := os.Mkdir(path, 0744); err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error())
}
filenames = append(filenames, header.Name)
case tar.TypeReg:
if PathTraversesUpward(header.Name) {
return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
}
path := filepath.Join(dst, header.Name)
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0744); err != nil {
return nil, fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error())
}
outFile, err := os.Create(path)
outFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return nil, fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error())
}
defer outFile.Close()
if _, err := io.Copy(outFile, tarReader); err != nil {
return nil, fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error())
}
filenames = append(filenames, header.Name)
default:
return nil, fmt.Errorf(
return fmt.Errorf(
"ExtractTarGz: unknown type: %v in %v",
header.Typeflag,
header.Name)
}
}
return filenames, nil
return nil
}