Files
mattermost/plugin/rpcplugin/hooks_test.go
Christopher Speller df6a7f8b19 MM-10249 Adding plugin ability to intercept posts before they reach the DB. (#8791)
* Adding plugin ability to intercept posts before they reach the DB.

* s/envoked/invoked/
2018-05-15 13:33:47 -07:00

238 lines
6.3 KiB
Go

package rpcplugin
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/plugin"
"github.com/mattermost/mattermost-server/plugin/plugintest"
)
// testHooksRPC serves hooks on one muxer, connects a RemoteHooks client to it
// through a second muxer joined by a pair of in-memory pipes, and hands that
// client to f. It returns the error from ConnectHooks if the connection could
// not be established, and nil once f has returned.
func testHooksRPC(hooks interface{}, f func(*RemoteHooks)) error {
	// Cross-wire two pipes so that whatever one muxer writes the other reads.
	serveIn, connectOut := io.Pipe()
	connectIn, serveOut := io.Pipe()

	serveMux := NewMuxer(NewReadWriteCloser(serveIn, serveOut), false)
	defer serveMux.Close()

	connectMux := NewMuxer(NewReadWriteCloser(connectIn, connectOut), true)
	defer connectMux.Close()

	streamID, stream := serveMux.Serve()
	go ServeHooks(hooks, stream, serveMux)

	client, err := ConnectHooks(connectMux.Connect(streamID), connectMux, "plugin_id")
	if err != nil {
		return err
	}
	defer client.Close()

	f(client)
	return nil
}
// TestHooks calls each hook on a RemoteHooks client that is connected, via
// testHooksRPC, to a mocked plugintest.Hooks, and checks that arguments and
// return values arrive intact on both sides.
func TestHooks(t *testing.T) {
	var api plugintest.API
	var hooks plugintest.Hooks
	defer hooks.AssertExpectations(t)

	assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) {
		hooks.On("OnActivate", mock.AnythingOfType("*rpcplugin.RemoteAPI")).Return(nil)
		assert.NoError(t, remote.OnActivate(&api))

		hooks.On("OnDeactivate").Return(nil)
		assert.NoError(t, remote.OnDeactivate())

		hooks.On("OnConfigurationChange").Return(nil)
		assert.NoError(t, remote.OnConfigurationChange())

		// The mocked handler verifies the request it receives and writes a
		// short body that is checked on the recorder below.
		hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
			w := args.Get(0).(http.ResponseWriter)
			r := args.Get(1).(*http.Request)
			assert.Equal(t, "/foo", r.URL.Path)
			assert.Equal(t, "POST", r.Method)
			body, err := ioutil.ReadAll(r.Body)
			assert.NoError(t, err)
			assert.Equal(t, "asdf", string(body))
			assert.Equal(t, "header", r.Header.Get("Test-Header"))
			w.Write([]byte("bar"))
		})

		w := httptest.NewRecorder()
		r, err := http.NewRequest("POST", "/foo", strings.NewReader("asdf"))
		// Check the error before touching r: it is nil when NewRequest fails.
		require.NoError(t, err)
		r.Header.Set("Test-Header", "header")
		remote.ServeHTTP(w, r)

		resp := w.Result()
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
		body, err := ioutil.ReadAll(resp.Body)
		assert.NoError(t, err)
		assert.Equal(t, "bar", string(body))

		hooks.On("ExecuteCommand", &model.CommandArgs{
			Command: "/foo",
		}).Return(&model.CommandResponse{
			Text: "bar",
		}, nil)
		// Call through remote, not the mock, so the command actually crosses
		// the RPC boundary like every other hook in this test.
		commandResponse, appErr := remote.ExecuteCommand(&model.CommandArgs{
			Command: "/foo",
		})
		assert.Nil(t, appErr)
		if assert.NotNil(t, commandResponse) {
			assert.Equal(t, "bar", commandResponse.Text)
		}

		hooks.On("MessageWillBePosted", mock.AnythingOfType("*model.Post")).Return(func(post *model.Post) *model.Post {
			post.Message += "_testing"
			return post
		}, "changemessage")
		post, changemessage := remote.MessageWillBePosted(&model.Post{Id: "1", Message: "base"})
		assert.Equal(t, "changemessage", changemessage)
		assert.Equal(t, "base_testing", post.Message)
		assert.Equal(t, "1", post.Id)

		hooks.On("MessageWillBeUpdated", mock.AnythingOfType("*model.Post"), mock.AnythingOfType("*model.Post")).Return(func(newPost, oldPost *model.Post) *model.Post {
			newPost.Message += "_testing"
			return newPost
		}, "changemessage2")
		post2, changemessage2 := remote.MessageWillBeUpdated(&model.Post{Id: "2", Message: "base2"}, &model.Post{Id: "OLD", Message: "OLDMESSAGE"})
		assert.Equal(t, "changemessage2", changemessage2)
		assert.Equal(t, "base2_testing", post2.Message)
		assert.Equal(t, "2", post2.Id)

		hooks.On("MessageHasBeenPosted", mock.AnythingOfType("*model.Post")).Return(nil)
		remote.MessageHasBeenPosted(&model.Post{})

		hooks.On("MessageHasBeenUpdated", mock.AnythingOfType("*model.Post"), mock.AnythingOfType("*model.Post")).Return(nil)
		remote.MessageHasBeenUpdated(&model.Post{}, &model.Post{})
	}))
}
// TestHooks_Concurrency issues two ServeHTTP calls at the same time. The
// mocked handler for "/1" blocks on a channel receive and the handler for any
// other path sends on that channel, so the test can only finish when both
// calls are being handled at once.
func TestHooks_Concurrency(t *testing.T) {
	var hooks plugintest.Hooks
	defer hooks.AssertExpectations(t)

	assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) {
		ch := make(chan bool)

		hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
			r := args.Get(1).(*http.Request)
			if r.URL.Path == "/1" {
				<-ch
			} else {
				ch <- true
			}
		})

		// Build every request up front on the test goroutine. require calls
		// t.FailNow, which must not run on a goroutine other than the one
		// running the test, and a failure after a worker had started would
		// leave that worker blocked.
		paths := []string{"/1", "/2"}
		reqs := make([]*http.Request, 0, len(paths))
		for _, path := range paths {
			req, err := http.NewRequest("GET", path, nil)
			require.NoError(t, err)
			reqs = append(reqs, req)
		}

		wg := sync.WaitGroup{}
		wg.Add(len(reqs))
		for _, req := range reqs {
			go func(req *http.Request) {
				// Deferred so wg.Wait is released even if ServeHTTP panics.
				defer wg.Done()
				// A recorder per call keeps the two in-flight requests from
				// sharing one ResponseWriter.
				remote.ServeHTTP(httptest.NewRecorder(), req)
			}(req)
		}
		wg.Wait()
	}))
}
// testHooks is a mock-backed hooks value whose only hook method in this file
// is OnActivate; the tests use it to exercise a partially implemented plugin.
type testHooks struct{ mock.Mock }
// OnActivate records the call, including api, on the embedded mock and
// returns the error value configured as that call's first return argument.
func (h *testHooks) OnActivate(api plugin.API) error {
	results := h.Called(api)
	return results.Error(0)
}
// TestHooks_PartiallyImplemented runs the RPC round trip against testHooks,
// which defines only OnActivate. It expects the remote to report exactly that
// hook as implemented, OnActivate to reach the mock, and OnDeactivate to
// return no error even though testHooks does not define it.
func TestHooks_PartiallyImplemented(t *testing.T) {
	var (
		api   plugintest.API
		hooks testHooks
	)
	defer hooks.AssertExpectations(t)

	exercise := func(remote *RemoteHooks) {
		names, err := remote.Implemented()
		assert.NoError(t, err)
		assert.Equal(t, []string{"OnActivate"}, names)

		hooks.On("OnActivate", mock.AnythingOfType("*rpcplugin.RemoteAPI")).Return(nil)
		activateErr := remote.OnActivate(&api)
		assert.NoError(t, activateErr)

		deactivateErr := remote.OnDeactivate()
		assert.NoError(t, deactivateErr)
	}

	rpcErr := testHooksRPC(&hooks, exercise)
	assert.NoError(t, rpcErr)
}
// benchmarkHooks is a stateless hooks value used by the benchmarks in this
// file. It defines only OnDeactivate and ServeHTTP.
type benchmarkHooks struct{}

// OnDeactivate does nothing and reports success.
func (*benchmarkHooks) OnDeactivate() error {
	return nil
}

// ServeHTTP reads the request body to the end, sets the Foo-Header response
// header to "foo", and replies with status 400 and the message "foo".
func (*benchmarkHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Both the bytes read and any read error are discarded.
	_, _ = ioutil.ReadAll(r.Body)

	header := w.Header()
	header.Set("Foo-Header", "foo")

	http.Error(w, "foo", http.StatusBadRequest)
}
// BenchmarkHooks_OnDeactivate times b.N remote OnDeactivate calls against a
// benchmarkHooks value. Connection setup and teardown in testHooksRPC fall
// outside the timed region.
func BenchmarkHooks_OnDeactivate(b *testing.B) {
	var impl benchmarkHooks

	timed := func(remote *RemoteHooks) {
		b.ResetTimer()
		for iter := 0; iter < b.N; iter++ {
			remote.OnDeactivate()
		}
		b.StopTimer()
	}

	err := testHooksRPC(&impl, timed)
	if err != nil {
		b.Fatal(err.Error())
	}
}
// BenchmarkHooks_ServeHTTP times b.N remote ServeHTTP calls against a
// benchmarkHooks value. Each iteration builds a fresh recorder and a POST
// request with a 20-byte body; both are part of the timed region.
func BenchmarkHooks_ServeHTTP(b *testing.B) {
	var hooks benchmarkHooks
	if err := testHooksRPC(&hooks, func(remote *RemoteHooks) {
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			w := httptest.NewRecorder()
			r, err := http.NewRequest("POST", "/foo", strings.NewReader("12345678901234567890"))
			if err != nil {
				// Fail fast rather than pass a nil request to ServeHTTP.
				b.Fatal(err)
			}
			remote.ServeHTTP(w, r)
		}
		b.StopTimer()
	}); err != nil {
		b.Fatal(err)
	}
}
// BenchmarkHooks_Unimplemented times b.N remote OnDeactivate calls against a
// testHooks value, which does not define OnDeactivate.
func BenchmarkHooks_Unimplemented(b *testing.B) {
	var impl testHooks

	timed := func(remote *RemoteHooks) {
		b.ResetTimer()
		for iter := 0; iter < b.N; iter++ {
			remote.OnDeactivate()
		}
		b.StopTimer()
	}

	err := testHooksRPC(&impl, timed)
	if err != nil {
		b.Fatal(err.Error())
	}
}