mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
PLT-7407: Back-end plugins (#7409)
* tie back-end plugins together * fix comment typo * add tests and a bit of polish * tests and polish * add test, don't let backend executable paths escape the plugin directory
This commit is contained in:
@@ -150,7 +150,7 @@ func InitApi(full bool) {
|
||||
BaseRoutes.PublicFile = BaseRoutes.Root.PathPrefix("/files/{file_id:[A-Za-z0-9]+}/public").Subrouter()
|
||||
|
||||
BaseRoutes.Plugins = BaseRoutes.ApiRoot.PathPrefix("/plugins").Subrouter()
|
||||
BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-]+}").Subrouter()
|
||||
BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}").Subrouter()
|
||||
|
||||
BaseRoutes.Commands = BaseRoutes.ApiRoot.PathPrefix("/commands").Subrouter()
|
||||
BaseRoutes.Command = BaseRoutes.Commands.PathPrefix("/{command_id:[A-Za-z0-9]+}").Subrouter()
|
||||
|
||||
@@ -110,5 +110,5 @@ func TestPlugin(t *testing.T) {
|
||||
_, resp = th.SystemAdminClient.RemovePlugin("bad.id")
|
||||
CheckNotFoundStatus(t, resp)
|
||||
|
||||
th.App.Srv.PluginEnv = nil
|
||||
th.App.PluginEnv = nil
|
||||
}
|
||||
|
||||
@@ -6,10 +6,13 @@ package app
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/mattermost/mattermost-server/plugin/pluginenv"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
Srv *Server
|
||||
Srv *Server
|
||||
PluginEnv *pluginenv.Environment
|
||||
}
|
||||
|
||||
var globalApp App
|
||||
|
||||
227
app/plugins.go
227
app/plugins.go
@@ -4,6 +4,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -19,14 +20,17 @@ import (
|
||||
"github.com/mattermost/mattermost-server/model"
|
||||
"github.com/mattermost/mattermost-server/utils"
|
||||
|
||||
"github.com/mattermost/mattermost-server/app/plugin"
|
||||
builtinplugin "github.com/mattermost/mattermost-server/app/plugin"
|
||||
"github.com/mattermost/mattermost-server/app/plugin/jira"
|
||||
"github.com/mattermost/mattermost-server/app/plugin/ldapextras"
|
||||
|
||||
"github.com/mattermost/mattermost-server/plugin"
|
||||
"github.com/mattermost/mattermost-server/plugin/pluginenv"
|
||||
)
|
||||
|
||||
type PluginAPI struct {
|
||||
id string
|
||||
router *mux.Router
|
||||
id string
|
||||
app *App
|
||||
}
|
||||
|
||||
func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error {
|
||||
@@ -37,37 +41,67 @@ func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (api *PluginAPI) PluginRouter() *mux.Router {
|
||||
return api.router
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) {
|
||||
return Global().GetTeamByName(name)
|
||||
return api.app.GetTeamByName(name)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetUserByName(name string) (*model.User, *model.AppError) {
|
||||
return Global().GetUserByUsername(name)
|
||||
func (api *PluginAPI) GetUserByUsername(name string) (*model.User, *model.AppError) {
|
||||
return api.app.GetUserByUsername(name)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) {
|
||||
return Global().GetChannelByName(name, teamId)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
|
||||
return Global().GetDirectChannel(userId1, userId2)
|
||||
func (api *PluginAPI) GetChannelByName(name, teamId string) (*model.Channel, *model.AppError) {
|
||||
return api.app.GetChannelByName(name, teamId)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) {
|
||||
return Global().CreatePostMissingChannel(post, true)
|
||||
return api.app.CreatePostMissingChannel(post, true)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) {
|
||||
type BuiltInPluginAPI struct {
|
||||
id string
|
||||
router *mux.Router
|
||||
app *App
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) LoadPluginConfiguration(dest interface{}) error {
|
||||
if b, err := json.Marshal(utils.Cfg.PluginSettings.Plugins[api.id]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return json.Unmarshal(b, dest)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) PluginRouter() *mux.Router {
|
||||
return api.router
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) {
|
||||
return api.app.GetTeamByName(name)
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) GetUserByName(name string) (*model.User, *model.AppError) {
|
||||
return api.app.GetUserByUsername(name)
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) {
|
||||
return api.app.GetChannelByName(name, teamId)
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
|
||||
return api.app.GetDirectChannel(userId1, userId2)
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) {
|
||||
return api.app.CreatePostMissingChannel(post, true)
|
||||
}
|
||||
|
||||
func (api *BuiltInPluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) {
|
||||
ldapInterface := einterfaces.GetLdapInterface()
|
||||
if ldapInterface == nil {
|
||||
return nil, model.NewAppError("GetLdapUserAttributes", "ent.ldap.disabled.app_error", nil, "", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
user, err := Global().GetUser(userId)
|
||||
user, err := api.app.GetUser(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,7 +109,7 @@ func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string)
|
||||
return ldapInterface.GetUserAttributes(*user.AuthData, attributes)
|
||||
}
|
||||
|
||||
func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) {
|
||||
func (api *BuiltInPluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) {
|
||||
token := ""
|
||||
isTokenFromQueryString := false
|
||||
|
||||
@@ -111,7 +145,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m
|
||||
return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
session, err := Global().GetSession(token)
|
||||
session, err := api.app.GetSession(token)
|
||||
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized)
|
||||
@@ -122,7 +156,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (api *PluginAPI) I18n(id string, r *http.Request) string {
|
||||
func (api *BuiltInPluginAPI) I18n(id string, r *http.Request) string {
|
||||
if r != nil {
|
||||
f, _ := utils.GetTranslationsAndLocale(nil, r)
|
||||
return f(id)
|
||||
@@ -131,16 +165,17 @@ func (api *PluginAPI) I18n(id string, r *http.Request) string {
|
||||
return f(id)
|
||||
}
|
||||
|
||||
func (a *App) InitPlugins() {
|
||||
plugins := map[string]plugin.Plugin{
|
||||
func (a *App) InitBuiltInPlugins() {
|
||||
plugins := map[string]builtinplugin.Plugin{
|
||||
"jira": &jira.Plugin{},
|
||||
"ldapextras": &ldapextras.Plugin{},
|
||||
}
|
||||
for id, p := range plugins {
|
||||
l4g.Info("Initializing plugin: " + id)
|
||||
api := &PluginAPI{
|
||||
api := &BuiltInPluginAPI{
|
||||
id: id,
|
||||
router: a.Srv.Router.PathPrefix("/plugins/" + id).Subrouter(),
|
||||
app: a,
|
||||
}
|
||||
p.Initialize(api)
|
||||
}
|
||||
@@ -155,19 +190,19 @@ func (a *App) InitPlugins() {
|
||||
}
|
||||
|
||||
func (a *App) ActivatePlugins() {
|
||||
if a.Srv.PluginEnv == nil {
|
||||
if a.PluginEnv == nil {
|
||||
l4g.Error("plugin env not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
plugins, err := a.Srv.PluginEnv.Plugins()
|
||||
plugins, err := a.PluginEnv.Plugins()
|
||||
if err != nil {
|
||||
l4g.Error("failed to start up plugins: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for _, plugin := range plugins {
|
||||
err := a.Srv.PluginEnv.ActivatePlugin(plugin.Manifest.Id)
|
||||
err := a.PluginEnv.ActivatePlugin(plugin.Manifest.Id)
|
||||
if err != nil {
|
||||
l4g.Error(err.Error())
|
||||
}
|
||||
@@ -176,48 +211,43 @@ func (a *App) ActivatePlugins() {
|
||||
}
|
||||
|
||||
func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *model.AppError) {
|
||||
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
tmpDir, err := ioutil.TempDir("", "plugintmp")
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.temp_dir.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
defer func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
}()
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
filenames, err := utils.ExtractTarGz(pluginFile, tmpDir)
|
||||
if err != nil {
|
||||
if err := utils.ExtractTarGz(pluginFile, tmpDir); err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.extract.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if len(filenames) == 0 {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.no_files.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
tmpPluginDir := tmpDir
|
||||
dir, err := ioutil.ReadDir(tmpDir)
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
splitPath := strings.Split(filenames[0], string(os.PathSeparator))
|
||||
|
||||
if len(splitPath) == 0 {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.bad_path.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
if len(dir) == 1 && dir[0].IsDir() {
|
||||
tmpPluginDir = filepath.Join(tmpPluginDir, dir[0].Name())
|
||||
}
|
||||
|
||||
manifestDir := filepath.Join(tmpDir, splitPath[0])
|
||||
|
||||
manifest, _, err := model.FindManifest(manifestDir)
|
||||
manifest, _, err := model.FindManifest(tmpPluginDir)
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.manifest.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
os.Rename(manifestDir, filepath.Join(a.Srv.PluginEnv.SearchPath(), manifest.Id))
|
||||
os.Rename(tmpPluginDir, filepath.Join(a.PluginEnv.SearchPath(), manifest.Id))
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.mvdir.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Should add manifest validation and error handling here
|
||||
|
||||
err = a.Srv.PluginEnv.ActivatePlugin(manifest.Id)
|
||||
err = a.PluginEnv.ActivatePlugin(manifest.Id)
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.activate.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
@@ -226,11 +256,11 @@ func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *m
|
||||
}
|
||||
|
||||
func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) {
|
||||
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
plugins, err := a.Srv.PluginEnv.ActivePlugins()
|
||||
plugins, err := a.PluginEnv.ActivePlugins()
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.get_plugins.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -244,16 +274,16 @@ func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) {
|
||||
}
|
||||
|
||||
func (a *App) RemovePlugin(id string) *model.AppError {
|
||||
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
return model.NewAppError("RemovePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
err := a.Srv.PluginEnv.DeactivatePlugin(id)
|
||||
err := a.PluginEnv.DeactivatePlugin(id)
|
||||
if err != nil {
|
||||
return model.NewAppError("RemovePlugin", "app.plugin.deactivate.app_error", nil, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = os.RemoveAll(filepath.Join(a.Srv.PluginEnv.SearchPath(), id))
|
||||
err = os.RemoveAll(filepath.Join(a.PluginEnv.SearchPath(), id))
|
||||
if err != nil {
|
||||
return model.NewAppError("RemovePlugin", "app.plugin.remove.app_error", nil, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -268,11 +298,11 @@ type ClientConfigPlugin struct {
|
||||
}
|
||||
|
||||
func (a *App) GetPluginsForClientConfig() string {
|
||||
if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable {
|
||||
return ""
|
||||
}
|
||||
|
||||
plugins, err := a.Srv.PluginEnv.ActivePlugins()
|
||||
plugins, err := a.PluginEnv.ActivePlugins()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -292,3 +322,94 @@ func (a *App) GetPluginsForClientConfig() string {
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (a *App) InitPlugins(pluginPath, webappPath string) {
|
||||
a.InitBuiltInPlugins()
|
||||
|
||||
if !utils.IsLicensed() || !*utils.License().Features.FutureFeatures || !*utils.Cfg.PluginSettings.Enable {
|
||||
return
|
||||
}
|
||||
|
||||
l4g.Info("Starting up plugins")
|
||||
|
||||
err := os.Mkdir(pluginPath, 0744)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
l4g.Error("failed to start up plugins: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
a.PluginEnv, err = pluginenv.New(
|
||||
pluginenv.SearchPath(pluginPath),
|
||||
pluginenv.WebappPath(webappPath),
|
||||
pluginenv.APIProvider(func(m *model.Manifest) (plugin.API, error) {
|
||||
return &PluginAPI{
|
||||
id: m.Id,
|
||||
app: a,
|
||||
}, nil
|
||||
}),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
l4g.Error("failed to start up plugins: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
utils.AddConfigListener(func(_, _ *model.Config) {
|
||||
for _, err := range a.PluginEnv.Hooks().OnConfigurationChange() {
|
||||
l4g.Error(err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}", a.ServePluginRequest)
|
||||
a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}/{anything:.*}", a.ServePluginRequest)
|
||||
|
||||
a.ActivatePlugins()
|
||||
}
|
||||
|
||||
func (a *App) ServePluginRequest(w http.ResponseWriter, r *http.Request) {
|
||||
token := ""
|
||||
|
||||
authHeader := r.Header.Get(model.HEADER_AUTH)
|
||||
if strings.HasPrefix(strings.ToUpper(authHeader), model.HEADER_BEARER+":") {
|
||||
token = authHeader[len(model.HEADER_BEARER)+1:]
|
||||
} else if strings.HasPrefix(strings.ToLower(authHeader), model.HEADER_TOKEN+":") {
|
||||
token = authHeader[len(model.HEADER_TOKEN)+1:]
|
||||
} else if cookie, _ := r.Cookie(model.SESSION_COOKIE_TOKEN); cookie != nil && (r.Method == "GET" || r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML) {
|
||||
token = cookie.Value
|
||||
} else {
|
||||
token = r.URL.Query().Get("access_token")
|
||||
}
|
||||
|
||||
r.Header.Del("Mattermost-User-Id")
|
||||
if token != "" {
|
||||
if session, err := a.GetSession(token); err != nil {
|
||||
r.Header.Set("Mattermost-User-Id", session.UserId)
|
||||
}
|
||||
}
|
||||
|
||||
cookies := r.Cookies()
|
||||
r.Header.Del("Cookie")
|
||||
for _, c := range cookies {
|
||||
if c.Name != model.SESSION_COOKIE_TOKEN {
|
||||
r.AddCookie(c)
|
||||
}
|
||||
}
|
||||
r.Header.Del(model.HEADER_AUTH)
|
||||
r.Header.Del("Referer")
|
||||
|
||||
newQuery := r.URL.Query()
|
||||
newQuery.Del("access_token")
|
||||
r.URL.RawQuery = newQuery.Encode()
|
||||
|
||||
params := mux.Vars(r)
|
||||
a.PluginEnv.Hooks().ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "plugin_id", params["plugin_id"])))
|
||||
}
|
||||
|
||||
func (a *App) ShutDownPlugins() {
|
||||
if a.PluginEnv == nil {
|
||||
return
|
||||
}
|
||||
for _, err := range a.PluginEnv.Shutdown() {
|
||||
l4g.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
"gopkg.in/throttled/throttled.v2/store/memstore"
|
||||
|
||||
"github.com/mattermost/mattermost-server/model"
|
||||
"github.com/mattermost/mattermost-server/plugin/pluginenv"
|
||||
"github.com/mattermost/mattermost-server/store"
|
||||
"github.com/mattermost/mattermost-server/utils"
|
||||
)
|
||||
@@ -30,7 +28,6 @@ type Server struct {
|
||||
WebSocketRouter *WebSocketRouter
|
||||
Router *mux.Router
|
||||
GracefulServer *graceful.Server
|
||||
PluginEnv *pluginenv.Environment
|
||||
}
|
||||
|
||||
var allowedMethods []string = []string{
|
||||
@@ -187,10 +184,6 @@ func (a *App) StartServer() {
|
||||
}()
|
||||
}
|
||||
|
||||
if utils.IsLicensed() && *utils.License().Features.FutureFeatures && *utils.Cfg.PluginSettings.Enable {
|
||||
a.StartupPlugins("plugins", "webapp/dist")
|
||||
}
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
if *utils.Cfg.ServiceSettings.ConnectionSecurity == model.CONN_SECURITY_TLS {
|
||||
@@ -226,30 +219,7 @@ func (a *App) StopServer() {
|
||||
a.Srv.Store.Close()
|
||||
HubStop()
|
||||
|
||||
a.ShutDownPlugins()
|
||||
|
||||
l4g.Info(utils.T("api.server.stop_server.stopped.info"))
|
||||
}
|
||||
|
||||
func (a *App) StartupPlugins(pluginPath, webappPath string) {
|
||||
l4g.Info("Starting up plugins")
|
||||
|
||||
err := os.Mkdir(pluginPath, 0744)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
err = nil
|
||||
} else {
|
||||
l4g.Error("failed to start up plugins: " + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
a.Srv.PluginEnv, err = pluginenv.New(
|
||||
pluginenv.SearchPath(pluginPath),
|
||||
pluginenv.WebappPath(webappPath),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
l4g.Error("failed to start up plugins: " + err.Error())
|
||||
}
|
||||
|
||||
a.ActivatePlugins()
|
||||
}
|
||||
|
||||
@@ -71,20 +71,22 @@ func runServer(configFileLocation string) {
|
||||
l4g.Error("Problem with file storage settings: " + err.Error())
|
||||
}
|
||||
|
||||
app.Global().NewServer()
|
||||
app.Global().InitStores()
|
||||
a := app.Global()
|
||||
a.NewServer()
|
||||
a.InitStores()
|
||||
api.InitRouter()
|
||||
|
||||
if model.BuildEnterpriseReady == "true" {
|
||||
a.LoadLicense()
|
||||
}
|
||||
a.InitPlugins("plugins", "webapp/dist")
|
||||
|
||||
wsapi.InitRouter()
|
||||
api4.InitApi(false)
|
||||
api.InitApi()
|
||||
app.Global().InitPlugins()
|
||||
wsapi.InitApi()
|
||||
web.InitWeb()
|
||||
|
||||
if model.BuildEnterpriseReady == "true" {
|
||||
app.Global().LoadLicense()
|
||||
}
|
||||
|
||||
if !utils.IsLicensed() && len(utils.Cfg.SqlSettings.DataSourceReplicas) > 1 {
|
||||
l4g.Warn(utils.T("store.sql.read_replicas_not_licensed.critical"))
|
||||
utils.Cfg.SqlSettings.DataSourceReplicas = utils.Cfg.SqlSettings.DataSourceReplicas[:1]
|
||||
@@ -98,7 +100,7 @@ func runServer(configFileLocation string) {
|
||||
|
||||
resetStatuses()
|
||||
|
||||
app.Global().StartServer()
|
||||
a.StartServer()
|
||||
|
||||
// If we allow testing then listen for manual testing URL hits
|
||||
if utils.Cfg.ServiceSettings.EnableTesting {
|
||||
@@ -118,7 +120,7 @@ func runServer(configFileLocation string) {
|
||||
}
|
||||
|
||||
if einterfaces.GetClusterInterface() != nil {
|
||||
app.Global().RegisterAllClusterMessageHandlers()
|
||||
a.RegisterAllClusterMessageHandlers()
|
||||
einterfaces.GetClusterInterface().StartInterNodeCommunication()
|
||||
}
|
||||
|
||||
@@ -132,7 +134,7 @@ func runServer(configFileLocation string) {
|
||||
}
|
||||
}
|
||||
|
||||
jobs.Srv.Store = app.Global().Srv.Store
|
||||
jobs.Srv.Store = a.Srv.Store
|
||||
if *utils.Cfg.JobSettings.RunJobs {
|
||||
jobs.Srv.StartWorkers()
|
||||
}
|
||||
@@ -157,7 +159,7 @@ func runServer(configFileLocation string) {
|
||||
jobs.Srv.StopSchedulers()
|
||||
jobs.Srv.StopWorkers()
|
||||
|
||||
app.Global().StopServer()
|
||||
a.StopServer()
|
||||
}
|
||||
|
||||
func runSecurityJob() {
|
||||
|
||||
12
i18n/en.json
12
i18n/en.json
@@ -3427,10 +3427,6 @@
|
||||
"id": "app.plugin.activate.app_error",
|
||||
"translation": "Unable to activate extracted plugin. Plugin may already exist and be activated."
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.bad_path.app_error",
|
||||
"translation": "Bad file path in extracted files"
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.deactivate.app_error",
|
||||
"translation": "Unable to deactivate plugin"
|
||||
@@ -3443,6 +3439,10 @@
|
||||
"id": "app.plugin.extract.app_error",
|
||||
"translation": "Encountered error extracting plugin"
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.filesystem.app_error",
|
||||
"translation": "Encountered filesystem error"
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.get_plugins.app_error",
|
||||
"translation": "Unable to get active plugins"
|
||||
@@ -3455,10 +3455,6 @@
|
||||
"id": "app.plugin.mvdir.app_error",
|
||||
"translation": "Unable to move plugin from temporary directory to final destination"
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.no_files.app_error",
|
||||
"translation": "No files found in the compressed folder"
|
||||
},
|
||||
{
|
||||
"id": "app.plugin.remove.app_error",
|
||||
"translation": "Unable to delete plugin"
|
||||
|
||||
@@ -12,6 +12,9 @@ type Hooks interface {
|
||||
// use the API, and the plugin will be terminated shortly after this invocation.
|
||||
OnDeactivate() error
|
||||
|
||||
// OnConfigurationChange is invoked when configuration changes may have been made.
|
||||
OnConfigurationChange() error
|
||||
|
||||
// ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for
|
||||
// the /plugins/{id} path will be routed to the plugin.
|
||||
//
|
||||
|
||||
@@ -4,6 +4,7 @@ package pluginenv
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -27,7 +28,7 @@ type Environment struct {
|
||||
apiProvider APIProviderFunc
|
||||
supervisorProvider SupervisorProviderFunc
|
||||
activePlugins map[string]ActivePlugin
|
||||
mutex sync.Mutex
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type Option func(*Environment)
|
||||
@@ -61,15 +62,13 @@ func (env *Environment) SearchPath() string {
|
||||
|
||||
// Returns a list of all plugins found within the environment.
|
||||
func (env *Environment) Plugins() ([]*model.BundleInfo, error) {
|
||||
env.mutex.Lock()
|
||||
defer env.mutex.Unlock()
|
||||
return ScanSearchPath(env.searchPath)
|
||||
}
|
||||
|
||||
// Returns a list of all currently active plugins within the environment.
|
||||
func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
|
||||
env.mutex.Lock()
|
||||
defer env.mutex.Unlock()
|
||||
env.mutex.RLock()
|
||||
defer env.mutex.RUnlock()
|
||||
|
||||
activePlugins := []*model.BundleInfo{}
|
||||
for _, p := range env.activePlugins {
|
||||
@@ -81,8 +80,8 @@ func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
|
||||
|
||||
// Returns the ids of the currently active plugins.
|
||||
func (env *Environment) ActivePluginIds() (ids []string) {
|
||||
env.mutex.Lock()
|
||||
defer env.mutex.Unlock()
|
||||
env.mutex.RLock()
|
||||
defer env.mutex.RUnlock()
|
||||
|
||||
for id := range env.activePlugins {
|
||||
ids = append(ids, id)
|
||||
@@ -200,13 +199,55 @@ func (env *Environment) Shutdown() (errs []error) {
|
||||
for _, activePlugin := range env.activePlugins {
|
||||
if activePlugin.Supervisor != nil {
|
||||
if err := activePlugin.Supervisor.Hooks().OnDeactivate(); err != nil {
|
||||
errs = append(errs, err)
|
||||
errs = append(errs, errors.Wrapf(err, "OnDeactivate() error for %v", activePlugin.BundleInfo.Manifest.Id))
|
||||
}
|
||||
if err := activePlugin.Supervisor.Stop(); err != nil {
|
||||
errs = append(errs, err)
|
||||
errs = append(errs, errors.Wrapf(err, "error stopping supervisor for %v", activePlugin.BundleInfo.Manifest.Id))
|
||||
}
|
||||
}
|
||||
}
|
||||
env.activePlugins = make(map[string]ActivePlugin)
|
||||
return
|
||||
}
|
||||
|
||||
type EnvironmentHooks struct {
|
||||
env *Environment
|
||||
}
|
||||
|
||||
func (env *Environment) Hooks() *EnvironmentHooks {
|
||||
return &EnvironmentHooks{env}
|
||||
}
|
||||
|
||||
// OnConfigurationChange invokes the OnConfigurationChange hook for all plugins. Any errors
|
||||
// encountered will be returned.
|
||||
func (h *EnvironmentHooks) OnConfigurationChange() (errs []error) {
|
||||
h.env.mutex.RLock()
|
||||
defer h.env.mutex.RUnlock()
|
||||
for _, activePlugin := range h.env.activePlugins {
|
||||
if activePlugin.Supervisor == nil {
|
||||
continue
|
||||
}
|
||||
if err := activePlugin.Supervisor.Hooks().OnConfigurationChange(); err != nil {
|
||||
errs = append(errs, errors.Wrapf(err, "OnConfigurationChange error for %v", activePlugin.BundleInfo.Manifest.Id))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ServeHTTP invokes the ServeHTTP hook for the plugin identified by the request or responds with a
|
||||
// 404 not found.
|
||||
//
|
||||
// It expects the request's context to have a plugin_id set.
|
||||
func (h *EnvironmentHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if id := r.Context().Value("plugin_id"); id != nil {
|
||||
if idstr, ok := id.(string); ok {
|
||||
h.env.mutex.RLock()
|
||||
defer h.env.mutex.RUnlock()
|
||||
if plugin, ok := h.env.activePlugins[idstr]; ok && plugin.Supervisor != nil {
|
||||
plugin.Supervisor.Hooks().ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package pluginenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -298,3 +302,70 @@ func TestEnvironment_ShutdownError(t *testing.T) {
|
||||
assert.Equal(t, env.ActivePluginIds(), []string{"foo"})
|
||||
assert.Len(t, env.Shutdown(), 2)
|
||||
}
|
||||
|
||||
func TestEnvironment_ConcurrentHookInvocations(t *testing.T) {
|
||||
dir := initTmpDir(t, map[string]string{
|
||||
"foo/plugin.json": `{"id": "foo", "backend": {}}`,
|
||||
})
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var provider MockProvider
|
||||
defer provider.AssertExpectations(t)
|
||||
|
||||
var api struct{ plugin.API }
|
||||
var supervisor MockSupervisor
|
||||
defer supervisor.AssertExpectations(t)
|
||||
var hooks plugintest.Hooks
|
||||
defer hooks.AssertExpectations(t)
|
||||
|
||||
env, err := New(
|
||||
SearchPath(dir),
|
||||
APIProvider(provider.API),
|
||||
SupervisorProvider(provider.Supervisor),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer env.Shutdown()
|
||||
|
||||
provider.On("API").Return(&api, nil)
|
||||
provider.On("Supervisor").Return(&supervisor, nil)
|
||||
|
||||
supervisor.On("Start").Return(nil)
|
||||
supervisor.On("Stop").Return(nil)
|
||||
supervisor.On("Hooks").Return(&hooks)
|
||||
|
||||
ch := make(chan bool)
|
||||
|
||||
hooks.On("OnActivate", &api).Return(nil)
|
||||
hooks.On("OnDeactivate").Return(nil)
|
||||
hooks.On("ServeHTTP", mock.AnythingOfType("*httptest.ResponseRecorder"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
|
||||
r := args.Get(1).(*http.Request)
|
||||
if r.URL.Path == "/1" {
|
||||
<-ch
|
||||
} else {
|
||||
ch <- true
|
||||
}
|
||||
})
|
||||
|
||||
assert.NoError(t, env.ActivatePlugin("foo"))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
req, err := http.NewRequest("GET", "/1", nil)
|
||||
require.NoError(t, err)
|
||||
env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
req, err := http.NewRequest("GET", "/2", nil)
|
||||
require.NoError(t, err)
|
||||
env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -22,6 +22,10 @@ func (m *Hooks) OnDeactivate() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *Hooks) OnConfigurationChange() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.Called(w, r)
|
||||
}
|
||||
|
||||
@@ -86,6 +86,15 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (h *LocalHooks) OnConfigurationChange(args, reply *struct{}) error {
|
||||
if hook, ok := h.hooks.(interface {
|
||||
OnConfigurationChange() error
|
||||
}); ok {
|
||||
return hook.OnConfigurationChange()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ServeHTTPArgs struct {
|
||||
ResponseWriterStream int64
|
||||
Request *http.Request
|
||||
@@ -122,11 +131,14 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) {
|
||||
server.ServeConn(conn)
|
||||
}
|
||||
|
||||
// These assignments are part of the wire protocol. You can add more, but should not change existing
|
||||
// assignments.
|
||||
const (
|
||||
remoteOnActivate = iota
|
||||
remoteOnDeactivate
|
||||
remoteServeHTTP
|
||||
maxRemoteHookCount
|
||||
remoteOnActivate = 0
|
||||
remoteOnDeactivate = 1
|
||||
remoteServeHTTP = 2
|
||||
remoteOnConfigurationChange = 3
|
||||
maxRemoteHookCount = iota
|
||||
)
|
||||
|
||||
type RemoteHooks struct {
|
||||
@@ -164,6 +176,13 @@ func (h *RemoteHooks) OnDeactivate() error {
|
||||
return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil)
|
||||
}
|
||||
|
||||
func (h *RemoteHooks) OnConfigurationChange() error {
|
||||
if !h.implemented[remoteOnConfigurationChange] {
|
||||
return nil
|
||||
}
|
||||
return h.client.Call("LocalHooks.OnConfigurationChange", struct{}{}, nil)
|
||||
}
|
||||
|
||||
func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.implemented[remoteServeHTTP] {
|
||||
http.NotFound(w, r)
|
||||
@@ -227,6 +246,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) {
|
||||
remote.implemented[remoteOnActivate] = true
|
||||
case "OnDeactivate":
|
||||
remote.implemented[remoteOnDeactivate] = true
|
||||
case "OnConfigurationChange":
|
||||
remote.implemented[remoteOnConfigurationChange] = true
|
||||
case "ServeHTTP":
|
||||
remote.implemented[remoteServeHTTP] = true
|
||||
}
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mattermost/mattermost-server/plugin"
|
||||
"github.com/mattermost/mattermost-server/plugin/plugintest"
|
||||
@@ -50,6 +52,9 @@ func TestHooks(t *testing.T) {
|
||||
hooks.On("OnDeactivate").Return(nil)
|
||||
assert.NoError(t, remote.OnDeactivate())
|
||||
|
||||
hooks.On("OnConfigurationChange").Return(nil)
|
||||
assert.NoError(t, remote.OnConfigurationChange())
|
||||
|
||||
hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(http.ResponseWriter)
|
||||
r := args.Get(1).(*http.Request)
|
||||
@@ -77,6 +82,45 @@ func TestHooks(t *testing.T) {
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHooks_Concurrency(t *testing.T) {
|
||||
var hooks plugintest.Hooks
|
||||
defer hooks.AssertExpectations(t)
|
||||
|
||||
assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) {
|
||||
ch := make(chan bool)
|
||||
|
||||
hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
|
||||
r := args.Get(1).(*http.Request)
|
||||
if r.URL.Path == "/1" {
|
||||
<-ch
|
||||
} else {
|
||||
ch <- true
|
||||
}
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
req, err := http.NewRequest("GET", "/1", nil)
|
||||
require.NoError(t, err)
|
||||
remote.ServeHTTP(rec, req)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
req, err := http.NewRequest("GET", "/2", nil)
|
||||
require.NoError(t, err)
|
||||
remote.ServeHTTP(rec, req)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}))
|
||||
}
|
||||
|
||||
type testHooks struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
@@ -2,26 +2,169 @@ package rpcplugin
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type asyncRead struct {
|
||||
b []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type asyncReadCloser struct {
|
||||
io.ReadCloser
|
||||
buffer bytes.Buffer
|
||||
read chan struct{}
|
||||
reads chan asyncRead
|
||||
close chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewAsyncReadCloser returns a ReadCloser that supports Close during Read.
|
||||
func NewAsyncReadCloser(r io.ReadCloser) io.ReadCloser {
|
||||
ret := &asyncReadCloser{
|
||||
ReadCloser: r,
|
||||
read: make(chan struct{}),
|
||||
reads: make(chan asyncRead),
|
||||
close: make(chan struct{}),
|
||||
}
|
||||
go ret.loop()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (r *asyncReadCloser) loop() {
|
||||
buf := make([]byte, 1024*8)
|
||||
var n int
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-r.read:
|
||||
n = 0
|
||||
if err == nil {
|
||||
n, err = r.ReadCloser.Read(buf)
|
||||
}
|
||||
select {
|
||||
case r.reads <- asyncRead{buf[:n], err}:
|
||||
case <-r.close:
|
||||
}
|
||||
case <-r.close:
|
||||
r.ReadCloser.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *asyncReadCloser) Read(b []byte) (int, error) {
|
||||
if r.buffer.Len() > 0 {
|
||||
return r.buffer.Read(b)
|
||||
}
|
||||
select {
|
||||
case r.read <- struct{}{}:
|
||||
case <-r.close:
|
||||
}
|
||||
select {
|
||||
case read := <-r.reads:
|
||||
if read.err != nil {
|
||||
return 0, read.err
|
||||
}
|
||||
n := copy(b, read.b)
|
||||
if n < len(read.b) {
|
||||
r.buffer.Write(read.b[n:])
|
||||
}
|
||||
return n, nil
|
||||
case <-r.close:
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func (r *asyncReadCloser) Close() error {
|
||||
r.closeOnce.Do(func() {
|
||||
close(r.close)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
type asyncWrite struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
type asyncWriteCloser struct {
|
||||
io.WriteCloser
|
||||
writeBuffer bytes.Buffer
|
||||
write chan struct{}
|
||||
writes chan asyncWrite
|
||||
close chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewAsyncWriteCloser returns a WriteCloser that supports Close during Write.
|
||||
func NewAsyncWriteCloser(w io.WriteCloser) io.WriteCloser {
|
||||
ret := &asyncWriteCloser{
|
||||
WriteCloser: w,
|
||||
write: make(chan struct{}),
|
||||
writes: make(chan asyncWrite),
|
||||
close: make(chan struct{}),
|
||||
}
|
||||
go ret.loop()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (w *asyncWriteCloser) loop() {
|
||||
var n int64
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-w.write:
|
||||
n = 0
|
||||
if err == nil {
|
||||
n, err = w.writeBuffer.WriteTo(w.WriteCloser)
|
||||
}
|
||||
select {
|
||||
case w.writes <- asyncWrite{int(n), err}:
|
||||
case <-w.close:
|
||||
}
|
||||
case <-w.close:
|
||||
w.WriteCloser.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncWriteCloser) Write(b []byte) (int, error) {
|
||||
if n, err := w.writeBuffer.Write(b); err != nil {
|
||||
return n, err
|
||||
}
|
||||
select {
|
||||
case w.write <- struct{}{}:
|
||||
case <-w.close:
|
||||
}
|
||||
select {
|
||||
case write := <-w.writes:
|
||||
return write.n, write.err
|
||||
case <-w.close:
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncWriteCloser) Close() error {
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.close)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
type rwc struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func (rwc *rwc) Close() (err error) {
|
||||
if f, ok := rwc.ReadCloser.(*os.File); ok {
|
||||
// https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion
|
||||
err = os.NewFile(f.Fd(), "").Close()
|
||||
} else {
|
||||
err = rwc.ReadCloser.Close()
|
||||
}
|
||||
werr := rwc.WriteCloser.Close()
|
||||
if err == nil {
|
||||
err = werr
|
||||
err = rwc.WriteCloser.Close()
|
||||
if rerr := rwc.ReadCloser.Close(); err == nil {
|
||||
err = rerr
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
73
plugin/rpcplugin/io_test.go
Normal file
73
plugin/rpcplugin/io_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package rpcplugin
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAsyncReadCloser(t *testing.T) {
|
||||
rf, w, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
r := NewAsyncReadCloser(rf)
|
||||
defer r.Close()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte("foo"))
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
foo, err := ioutil.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "foo", string(foo))
|
||||
}
|
||||
|
||||
func TestNewAsyncReadCloser_CloseDuringRead(t *testing.T) {
|
||||
rf, w, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
defer w.Close()
|
||||
|
||||
r := NewAsyncReadCloser(rf)
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
r.Close()
|
||||
}()
|
||||
r.Read(make([]byte, 10))
|
||||
}
|
||||
|
||||
func TestNewAsyncWriteCloser(t *testing.T) {
|
||||
r, wf, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
w := NewAsyncWriteCloser(wf)
|
||||
defer w.Close()
|
||||
|
||||
go func() {
|
||||
foo, err := ioutil.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "foo", string(foo))
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
n, err := w.Write([]byte("foo"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, n)
|
||||
}
|
||||
|
||||
func TestNewAsyncWriteCloser_CloseDuringWrite(t *testing.T) {
|
||||
r, wf, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
w := NewAsyncWriteCloser(wf)
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
w.Close()
|
||||
}()
|
||||
w.Write(make([]byte, 10))
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func NewIPC() (io.ReadWriteCloser, []*os.File, error) {
|
||||
childWriter.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return NewReadWriteCloser(parentReader, parentWriter), []*os.File{childReader, childWriter}, nil
|
||||
return NewReadWriteCloser(NewAsyncReadCloser(parentReader), NewAsyncWriteCloser(parentWriter)), []*os.File{childReader, childWriter}, nil
|
||||
}
|
||||
|
||||
// Returns the IPC instance inherited by the process from its parent.
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -123,7 +124,11 @@ func SupervisorProvider(bundle *model.BundleInfo) (plugin.Supervisor, error) {
|
||||
} else if bundle.Manifest.Backend == nil || bundle.Manifest.Backend.Executable == "" {
|
||||
return nil, fmt.Errorf("no backend executable specified")
|
||||
}
|
||||
executable := filepath.Clean(filepath.Join(".", bundle.Manifest.Backend.Executable))
|
||||
if strings.HasPrefix(executable, "..") {
|
||||
return nil, fmt.Errorf("invalid backend executable")
|
||||
}
|
||||
return &Supervisor{
|
||||
executable: filepath.Join(bundle.Path, bundle.Manifest.Backend.Executable),
|
||||
executable: filepath.Join(bundle.Path, executable),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -43,6 +43,19 @@ func TestSupervisor(t *testing.T) {
|
||||
require.NoError(t, supervisor.Stop())
|
||||
}
|
||||
|
||||
func TestSupervisor_InvalidExecutablePath(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
ioutil.WriteFile(filepath.Join(dir, "plugin.json"), []byte(`{"id": "foo", "backend": {"executable": "/foo/../../backend.exe"}}`), 0600)
|
||||
|
||||
bundle := model.BundleInfoForPath(dir)
|
||||
supervisor, err := SupervisorProvider(bundle)
|
||||
assert.Nil(t, supervisor)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// If plugin development goes really wrong, let's make sure plugin activation won't block forever.
|
||||
func TestSupervisor_StartTimeout(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "")
|
||||
|
||||
@@ -13,19 +13,16 @@ import (
|
||||
)
|
||||
|
||||
// ExtractTarGz takes in an io.Reader containing the bytes for a .tar.gz file and
|
||||
// a destination string to extract to. A list of the file and directory names that
|
||||
// were extracted is returned.
|
||||
func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) {
|
||||
// a destination string to extract to.
|
||||
func ExtractTarGz(gzipStream io.Reader, dst string) error {
|
||||
uncompressedStream, err := gzip.NewReader(gzipStream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error())
|
||||
}
|
||||
defer uncompressedStream.Close()
|
||||
|
||||
tarReader := tar.NewReader(uncompressedStream)
|
||||
|
||||
filenames := []string{}
|
||||
|
||||
for true {
|
||||
header, err := tarReader.Next()
|
||||
|
||||
@@ -34,50 +31,46 @@ func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error())
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if PathTraversesUpward(header.Name) {
|
||||
return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
|
||||
return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
|
||||
}
|
||||
|
||||
path := filepath.Join(dst, header.Name)
|
||||
if err := os.Mkdir(path, 0744); err != nil && !os.IsExist(err) {
|
||||
return nil, fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error())
|
||||
}
|
||||
|
||||
filenames = append(filenames, header.Name)
|
||||
case tar.TypeReg:
|
||||
if PathTraversesUpward(header.Name) {
|
||||
return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
|
||||
return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards")
|
||||
}
|
||||
|
||||
path := filepath.Join(dst, header.Name)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
if err := os.MkdirAll(dir, 0744); err != nil {
|
||||
return nil, fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error())
|
||||
}
|
||||
|
||||
outFile, err := os.Create(path)
|
||||
outFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error())
|
||||
}
|
||||
defer outFile.Close()
|
||||
if _, err := io.Copy(outFile, tarReader); err != nil {
|
||||
return nil, fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error())
|
||||
return fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error())
|
||||
}
|
||||
|
||||
filenames = append(filenames, header.Name)
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
return fmt.Errorf(
|
||||
"ExtractTarGz: unknown type: %v in %v",
|
||||
header.Typeflag,
|
||||
header.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return filenames, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user